diff options
| -rw-r--r-- | CMakeLists.txt | 9 | ||||
| -rw-r--r-- | COPYING | 17 | ||||
| -rw-r--r-- | src/CMakeLists.txt | 51 | ||||
| -rw-r--r-- | src/client.c | 1101 | ||||
| -rw-r--r-- | src/deserialize.c | 278 | ||||
| -rw-r--r-- | src/deserialize.h | 9 | ||||
| -rw-r--r-- | src/log.h | 57 | ||||
| -rw-r--r-- | src/misc.c | 48 | ||||
| -rw-r--r-- | src/misc.h | 18 | ||||
| -rw-r--r-- | src/mqtt.h | 98 | ||||
| -rw-r--r-- | src/packet.c | 104 | ||||
| -rw-r--r-- | src/packet.h | 123 | ||||
| -rw-r--r-- | src/queue.h | 846 | ||||
| -rw-r--r-- | src/serialize.c | 309 | ||||
| -rw-r--r-- | src/serialize.h | 9 | ||||
| -rw-r--r-- | src/socket.c | 92 | ||||
| -rw-r--r-- | src/socket.h | 12 | ||||
| -rw-r--r-- | src/socketstream.c | 75 | ||||
| -rw-r--r-- | src/socketstream.h | 16 | ||||
| -rw-r--r-- | src/stream.c | 77 | ||||
| -rw-r--r-- | src/stream.h | 50 | ||||
| -rw-r--r-- | src/stream_mqtt.c | 98 | ||||
| -rw-r--r-- | src/stream_mqtt.h | 17 | ||||
| -rw-r--r-- | src/stringbuf.c | 73 | ||||
| -rw-r--r-- | src/stringbuf.h | 31 | ||||
| -rw-r--r-- | tools/CMakeLists.txt | 7 | ||||
| -rw-r--r-- | tools/amalgamate.py | 90 | ||||
| -rw-r--r-- | tools/getopt.c | 358 | ||||
| -rw-r--r-- | tools/getopt.h | 175 | ||||
| -rw-r--r-- | tools/pub.c | 96 | ||||
| -rw-r--r-- | tools/sub.c | 103 |
31 files changed, 4447 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..0db2d0c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,9 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 3.0) +PROJECT(mqtt C) + +ADD_SUBDIRECTORY(src) + +OPTION(BUILD_TOOLS "Build tools" ON) +IF(BUILD_TOOLS) + ADD_SUBDIRECTORY(tools) +ENDIF() @@ -0,0 +1,17 @@ +Copyright (c) 2017 Oskari Timperi + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. +2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. +3. This notice may not be removed or altered from any source distribution. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..7832b83 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,51 @@ +ADD_LIBRARY(mqtt STATIC + client.c + deserialize.c + misc.c + packet.c + serialize.c + socket.c + socketstream.c + stream.c + stream_mqtt.c + stringbuf.c +) + +IF(MSVC) +ELSEIF(CMAKE_COMPILER_IS_GNUCC OR (CMAKE_C_COMPILER_ID MATCHES Clang)) + TARGET_COMPILE_OPTIONS(mqtt PRIVATE -Wall -Wextra $<$<CONFIG:Debug>:-O0>) +ENDIF() + +SET(MQTT_LOG_LEVEL "" CACHE STRING + "(DEBUG|INFO|WARNING|ERROR) or leave empty for no logging") + +IF(MQTT_LOG_LEVEL) + TARGET_COMPILE_DEFINITIONS(mqtt PRIVATE + $<$<CONFIG:Debug>:LOG_LEVEL=LOG_LEVEL_${MQTT_LOG_LEVEL}>) +ENDIF() + +OPTION(MQTT_STREAM_HEXDUMP_READ "Hexdump all read data to stdout" OFF) +OPTION(MQTT_STREAM_HEXDUMP_WRITE "Hexdump all written data to stdout" OFF) + +IF(MQTT_STREAM_HEXDUMP_READ) + TARGET_COMPILE_DEFINITIONS(mqtt PRIVATE STREAM_HEXDUMP_READ) +ENDIF() + +IF(MQTT_STREAM_HEXDUMP_WRITE) + TARGET_COMPILE_DEFINITIONS(mqtt PRIVATE STREAM_HEXDUMP_WRITE) +ENDIF() + +TARGET_INCLUDE_DIRECTORIES(mqtt INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) + +OPTION(MQTT_AMALGAMATE "Create an amalgamation of all the sources" OFF) +IF(MQTT_AMALGAMATE) + SET(AMALGAMATION_DIR ${PROJECT_SOURCE_DIR}/amalgamation CACHE PATH "Where to output the amalgamation") + SET(AMALGAMATION_TOOL ${PROJECT_SOURCE_DIR}/tools/amalgamate.py) + FIND_PROGRAM(PYTHON python) + ADD_CUSTOM_COMMAND(OUTPUT ${AMALGAMATION_DIR}/mqtt.c + COMMAND ${CMAKE_COMMAND} -E make_directory ${AMALGAMATION_DIR} + COMMAND ${PYTHON} ${AMALGAMATION_TOOL} ${AMALGAMATION_DIR}/mqtt.c + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/mqtt.h ${AMALGAMATION_DIR} + ) + ADD_CUSTOM_TARGET(amalgamate DEPENDS ${AMALGAMATION_DIR}/mqtt.c) +ENDIF() diff --git a/src/client.c b/src/client.c new file mode 100644 index 0000000..f790f42 --- /dev/null +++ b/src/client.c @@ -0,0 +1,1101 @@ +#include "mqtt.h" +#include "packet.h" +#include "stream.h" +#include "socketstream.h" +#include "socket.h" +#include "misc.h" +#include "serialize.h" +#include "deserialize.h" +#include "log.h" + +#include "queue.h" + +#include <stdlib.h> +#include <stdio.h> +#include <string.h> +#include <assert.h> +#include <time.h> +#include <inttypes.h> + +#if (LOG_LEVEL == LOG_LEVEL_DEBUG) && !defined(PRId64) +#error define PRId64 for your platform +#endif + +#ifdef __APPLE__ +#include <sys/select.h> +#endif + +TAILQ_HEAD(MessageList, MqttPacket); +typedef struct MessageList MessageList; + +struct MqttClient +{ + SocketStream stream; + /* client id, NULL if we want to have server generated id */ + char *clientId; + /* set to 1 if we want to have a clean session */ + int cleanSession; + /* remote host and port */ + char *host; + short port; + /* keepalive interval in seconds */ + int keepAlive; + /* user specified data, not used by us */ + void *userData; + /* callback called after connection is made */ + MqttClientOnConnectCallback onConnect; + /* callback called after subscribe is done */ + MqttClientOnSubscribeCallback onSubscribe; + /* callback called after subscribe is done */ + MqttClientOnUnsubscribeCallback onUnsubscribe; + /* callback called when a message is received */ + MqttClientOnMessageCallback onMessage; + /* callback called after publish is done and acknowledged */ + MqttClientOnPublishCallback onPublish; + int stopped; + /* packets waiting to be sent over network */ + SIMPLEQ_HEAD(, MqttPacket) sendQueue; + /* sent messages that are not done yet */ + MessageList outMessages; + /* received messages that are not done yet */ + MessageList inMessages; + int sessionPresent; + /* when was the last packet sent */ + int64_t lastPacketSentTime; + /* next packet id */ + uint16_t packetId; + /* timeout after which to retry sending messages */ + int retryTimeout; + /* maximum number of inflight messages (not packets!) */ + int maxInflight; + /* maximum number of queued messages (not packets!) */ + int maxQueued; + /* 1 if PINGREQ is sent and we are waiting for PINGRESP, 0 otherwise */ + int pingSent; +}; + +enum MessageState +{ + MessageStateQueued = 100, + MessageStateSend, + MessageStateSent +}; + +static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); +static int MqttClientQueueSimplePacket(MqttClient *client, int type); +static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet); +static int MqttClientRecvPacket(MqttClient *client); +static uint16_t MqttClientNextPacketId(MqttClient *client); +static void MqttClientProcessMessageQueue(MqttClient *client); + +static inline int MqttClientInflightMessageCount(MqttClient *client) +{ + MqttPacket *packet; + int queued = 0; + int inMessagesCount = 0; + int outMessagesCount = 0; + + TAILQ_FOREACH(packet, &client->outMessages, messages) + { + if (packet->state == MessageStateQueued) + { + ++queued; + } + + ++outMessagesCount; + } + + TAILQ_FOREACH(packet, &client->inMessages, messages) + { + ++inMessagesCount; + } + + return inMessagesCount + outMessagesCount - queued; +} + +static char *CopyString(const char *s, int n) +{ + char *result = NULL; + + if (n < 0) + n = strlen(s); + + result = malloc(n+1); + + assert(result != NULL); + + memcpy(result, s, n); + result[n] = '\0'; + + return result; +} + +MqttClient *MqttClientNew(const char *clientId, int cleanSession) +{ + MqttClient *client; + + client = calloc(1, sizeof(*client)); + + if (!client) + { + return NULL; + } + + if (clientId == NULL) + { + client->clientId = CopyString("", 0); + } + else + { + client->clientId = CopyString(clientId, -1); + } + + client->cleanSession = cleanSession; + + client->stream.sock = -1; + + client->retryTimeout = 20; + + client->maxQueued = 0; + client->maxInflight = 20; + + TAILQ_INIT(&client->outMessages); + TAILQ_INIT(&client->inMessages); + SIMPLEQ_INIT(&client->sendQueue); + + return client; +} + +void MqttClientFree(MqttClient *client) +{ + if (client->clientId) + { + free(client->clientId); + } + + if (client->host) + { + free(client->host); + } + + free(client); +} + +void MqttClientSetUserData(MqttClient *client, void *userData) +{ + assert(client != NULL); + client->userData = userData; +} + +void *MqttClientGetUserData(MqttClient *client) +{ + assert(client != NULL); + return client->userData; +} + +void MqttClientSetOnConnect(MqttClient *client, MqttClientOnConnectCallback cb) +{ + assert(client != NULL); + client->onConnect = cb; +} + +void MqttClientSetOnSubscribe(MqttClient *client, + MqttClientOnSubscribeCallback cb) +{ + assert(client != NULL); + client->onSubscribe = cb; +} + +void MqttClientSetOnUnsubscribe(MqttClient *client, + MqttClientOnUnsubscribeCallback cb) +{ + assert(client != NULL); + client->onUnsubscribe = cb; +} + +void MqttClientSetOnMessage(MqttClient *client, + MqttClientOnMessageCallback cb) +{ + assert(client != NULL); + client->onMessage = cb; +} + +void MqttClientSetOnPublish(MqttClient *client, + MqttClientOnPublishCallback cb) +{ + assert(client != NULL); + client->onPublish = cb; +} + +int MqttClientConnect(MqttClient *client, const char *host, short port, + int keepAlive) +{ + int sock; + MqttPacketConnect *packet; + + assert(client != NULL); + assert(host != NULL); + + client->host = CopyString(host, -1); + client->port = port; + client->keepAlive = keepAlive; + + if (keepAlive < 0) + { + LOG_ERROR("invalid keepAlive: %d", keepAlive); + return -1; + } + + LOG_DEBUG("connecting"); + + if ((sock = SocketConnect(host, port)) == -1) + { + LOG_ERROR("SocketConnect failed!"); + return -1; + } + + if (SocketStreamOpen(&client->stream, sock) == -1) + { + return -1; + } + + packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect); + + if (!packet) + return -1; + + if (client->cleanSession) + { + packet->connectFlags |= 0x02; + } + + packet->keepAlive = client->keepAlive; + + if (StringBufInitFromCString(&packet->clientId, client->clientId, -1) == -1) + { + free(packet); + return -1; + } + + MqttClientQueuePacket(client, &packet->base); + + return 0; +} + +int MqttClientDisconnect(MqttClient *client) +{ + LOG_DEBUG("disconnecting"); + return MqttClientQueueSimplePacket(client, MqttPacketTypeDisconnect); +} + +int MqttClientRunOnce(MqttClient *client) +{ + fd_set rfd, wfd; + struct timeval tv; + int rv; + + assert(client != NULL); + + if (client->stream.sock == -1) + { + LOG_ERROR("invalid socket"); + return -1; + } + + FD_ZERO(&rfd); + FD_ZERO(&wfd); + + FD_SET(client->stream.sock, &rfd); + + /* Handle outMessages and inMessages, moving queued messages to sendQueue + if there are less than maxInflight number of messages in flight */ + MqttClientProcessMessageQueue(client); + + if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + LOG_DEBUG("nothing to write"); + } + else + { + FD_SET(client->stream.sock, &wfd); + } + + // TODO: break select when queuing packets (need to protect queue with mutex + // to allow queuing packets from another thread) + + memset(&tv, 0, sizeof(tv)); + tv.tv_sec = client->keepAlive; + tv.tv_usec = 0; + + LOG_DEBUG("selecting"); + rv = select(client->stream.sock+1, &rfd, &wfd, NULL, &tv); + + if (rv == -1) + { + LOG_ERROR("select failed"); + return -1; + } + else if (rv) + { + LOG_DEBUG("select rv=%d", rv); + + if (FD_ISSET(client->stream.sock, &wfd)) + { + MqttPacket *packet; + + LOG_DEBUG("socket writable"); + + packet = SIMPLEQ_FIRST(&client->sendQueue); + + if (packet) + { + SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); + + if (MqttClientSendPacket(client, packet) == -1) + { + LOG_ERROR("MqttClientSendPacket failed"); + client->stopped = 1; + } + } + } + + if (FD_ISSET(client->stream.sock, &rfd)) + { + LOG_DEBUG("socket readable"); + + if (MqttClientRecvPacket(client) == -1) + { + LOG_ERROR("MqttClientRecvPacket failed"); + client->stopped = 1; + } + } + } + else + { + LOG_DEBUG("select timeout"); + + if (client->pingSent) + { + LOG_ERROR("no PINGRESP received in time"); + client->pingSent = 0; + client->stopped = 1; + } + else if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + int64_t elapsed = GetCurrentTime() - client->lastPacketSentTime; + if (elapsed/1000 >= client->keepAlive) + { + MqttClientQueueSimplePacket(client, MqttPacketTypePingReq); + client->pingSent = 1; + } + } + } + + if (client->stopped) + { + SocketDisconnect(client->stream.sock); + client->stream.sock = -1; + } + + return 0; +} + +int MqttClientRun(MqttClient *client) +{ + assert(client != NULL); + + while (!client->stopped) + { + if (MqttClientRunOnce(client) == -1) + return -1; + } + + return 0; +} + +int MqttClientSubscribe(MqttClient *client, const char *topicFilter, + int qos) +{ + MqttPacketSubscribe *packet = NULL; + + assert(client != NULL); + assert(topicFilter != NULL); + assert(qos >= 0 && qos <= 2); + + packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( + MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); + + if (!packet) + return -1; + + if (StringBufInitFromCString(&packet->topicFilter, topicFilter, -1) == -1) + { + MqttPacketFree((MqttPacket *) packet); + return -1; + } + + packet->qos = qos; + + MqttClientQueuePacket(client, (MqttPacket *) packet); + + TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + + return MqttPacketId(packet); +} + +int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter) +{ + MqttPacketUnsubscribe *packet = NULL; + + assert(client != NULL); + assert(topicFilter != NULL); + + packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew( + MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client)); + + if (StringBufInitFromCString(&packet->topicFilter, topicFilter, -1) == -1) + { + MqttPacketFree((MqttPacket *) packet); + return -1; + } + + MqttClientQueuePacket(client, (MqttPacket *) packet); + + TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + + return MqttPacketId(packet); +} + +static inline int MqttClientOutMessagesLen(MqttClient *client) +{ + MqttPacket *packet; + int count = 0; + TAILQ_FOREACH(packet, &client->outMessages, messages) + { + ++count; + } + return count; +} + +int MqttClientPublish(MqttClient *client, int qos, int retain, + const char *topic, const void *data, size_t size) +{ + MqttPacketPublish *packet; + + assert(client != NULL); + + /* first check if the queue is already full */ + if (qos > 0 && client->maxQueued > 0 && + MqttClientOutMessagesLen(client) >= client->maxQueued) + { + return -1; + } + + if (qos > 0) + { + packet = (MqttPacketPublish *) MqttPacketWithIdNew( + MqttPacketTypePublish, MqttClientNextPacketId(client)); + } + else + { + packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish); + } + + if (!packet) + return -1; + + packet->qos = qos; + packet->retain = retain; + + if (StringBufInitFromCString(&packet->topicName, topic, -1) == -1) + { + MqttPacketFree((MqttPacket *) packet); + return -1; + } + + if (StringBufInitFromData(&packet->message, data, size) == -1) + { + MqttPacketFree((MqttPacket *) packet); + return -1; + } + + if (qos > 0) + { + /* check how many messages there are coming in and going out currently + that are not yet done */ + if (client->maxInflight == 0 || + MqttClientInflightMessageCount(client) < client->maxInflight) + { + LOG_DEBUG("setting message (%d) state to MessageStateSend", + MqttPacketId(packet)); + packet->base.state = MessageStateSend; + } + else + { + LOG_DEBUG("setting message (%d) state to MessageStateQueued", + MqttPacketId(packet)); + packet->base.state = MessageStateQueued; + } + + /* add the message to the outMessages queue to wait for processing */ + TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, + messages); + } + else + { + MqttClientQueuePacket(client, (MqttPacket *) packet); + } + + if (qos > 0) + return MqttPacketId(packet); + + return 0; +} + +int MqttClientPublishCString(MqttClient *client, int qos, int retain, + const char *topic, const char *msg) +{ + return MqttClientPublish(client, qos, retain, topic, msg, strlen(msg)); +} + +void MqttClientSetPublishRetryTimeout(MqttClient *client, int timeout) +{ + assert(client != NULL); + client->retryTimeout = timeout; +} + +void MqttClientSetMaxMessagesInflight(MqttClient *client, int max) +{ + assert(client != NULL); + client->maxInflight = max; +} + +void MqttClientSetMaxQueuedMessages(MqttClient *client, int max) +{ + assert(client != NULL); + client->maxQueued = max; +} + +static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet) +{ + assert(client != NULL); + LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type)); + SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue); +} + +static int MqttClientQueueSimplePacket(MqttClient *client, int type) +{ + MqttPacket *packet = MqttPacketNew(type); + if (!packet) + return -1; + MqttClientQueuePacket(client, packet); + return 0; +} + +static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet) +{ + if (MqttPacketSerialize(packet, &client->stream.base) == -1) + return -1; + + packet->sentAt = GetCurrentTime(); + client->lastPacketSentTime = packet->sentAt; + + if (packet->type == MqttPacketTypeDisconnect) + { + client->stopped = 1; + } + + /* If the packet is not on any message list, it can be removed after + sending. */ + if (TAILQ_NEXT(packet, messages) == NULL && + TAILQ_PREV(packet, MessageList, messages) == NULL && + TAILQ_FIRST(&client->inMessages) != packet && + TAILQ_FIRST(&client->outMessages) != packet) + { + LOG_DEBUG("freeing packet %s after sending", + MqttPacketName(MqttPacketType(packet))); + MqttPacketFree(packet); + } + + return 0; +} + +static void MqttClientHandleConnAck(MqttClient *client, + MqttPacketConnAck *packet) +{ + client->sessionPresent = packet->connAckFlags & 1; + + LOG_DEBUG("sessionPresent:%d", client->sessionPresent); + + if (client->onConnect) + { + LOG_DEBUG("calling onConnect rc:%d", packet->returnCode); + client->onConnect(client, packet->returnCode, client->sessionPresent); + } +} + +static void MqttClientHandlePingResp(MqttClient *client) +{ + LOG_DEBUG("got ping response"); + client->pingSent = 0; +} + +static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) +{ + MqttPacket *sub; + + assert(client != NULL); + assert(packet != NULL); + + TAILQ_FOREACH(sub, &client->outMessages, messages) + { + if (MqttPacketType(sub) == MqttPacketTypeSubscribe && + MqttPacketId(sub) == MqttPacketId(packet)) + { + break; + } + } + + if (!sub) + { + LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet)); + client->stopped = 1; + } + else + { + TAILQ_REMOVE(&client->outMessages, sub, messages); + MqttPacketFree(sub); + + if (client->onSubscribe) + { + LOG_DEBUG("calling onSubscribe id:%d rc:%d", MqttPacketId(packet), + packet->returnCode); + client->onSubscribe(client, MqttPacketId(packet), + packet->returnCode); + } + } + + MqttPacketFree((MqttPacket *) packet); +} + +static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet) +{ + if (MqttPacketPublishQos(packet) == 2) + { + /* Check if we have sent a PUBREC previously with the same id. If we + have, we have to resend the PUBREC. We must not call the onMessage + callback again. */ + + MqttPacket *pubRec; + + TAILQ_FOREACH(pubRec, &client->inMessages, messages) + { + if (MqttPacketId(pubRec) == MqttPacketId(packet) && + MqttPacketType(pubRec) == MqttPacketTypePubRec) + { + break; + } + } + + if (pubRec) + { + LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet)); + // MqttPacketWithId *pubRec = (MqttPacketWithId *) pubRecNode->packet; + MqttClientQueuePacket(client, pubRec); + MqttPacketFree((MqttPacket *) packet); + return; + } + } + + if (client->onMessage) + { + LOG_DEBUG("calling onMessage"); + client->onMessage(client, + packet->topicName.data, + packet->message.data, + packet->message.len); + } + + if (MqttPacketPublishQos(packet) > 0) + { + int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck : + MqttPacketTypePubRec; + + MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet)); + + if (MqttPacketPublishQos(packet) == 2) + { + /* append to inMessages as we need a reply to this response */ + TAILQ_INSERT_TAIL(&client->inMessages, resp, messages); + } + + MqttClientQueuePacket(client, resp); + } + + MqttPacketFree((MqttPacket *) packet); +} + +static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet) +{ + MqttPacket *pub; + + TAILQ_FOREACH(pub, &client->outMessages, messages) + { + if (MqttPacketId(pub) == MqttPacketId(packet) && + MqttPacketType(pub) == MqttPacketTypePublish) + { + break; + } + } + + if (!pub) + { + LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); + client->stopped = 1; + } + else + { + TAILQ_REMOVE(&client->outMessages, pub, messages); + MqttPacketFree(pub); + + if (client->onPublish) + { + client->onPublish(client, MqttPacketId(packet)); + } + } + + MqttPacketFree(packet); +} + +static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet) +{ + MqttPacket *pub; + + assert(client != NULL); + assert(packet != NULL); + + TAILQ_FOREACH(pub, &client->outMessages, messages) + { + if (MqttPacketId(pub) == MqttPacketId(packet) && + MqttPacketType(pub) == MqttPacketTypePublish) + { + break; + } + } + + if (!pub) + { + LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); + client->stopped = 1; + } + else + { + MqttPacket *pubRel; + + TAILQ_REMOVE(&client->outMessages, pub, messages); + MqttPacketFree(pub); + + pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet)); + pubRel->state = MessageStateSend; + + TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages); + } + + MqttPacketFree(packet); +} + +static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet) +{ + MqttPacket *pubRec; + + assert(client != NULL); + assert(packet != NULL); + + TAILQ_FOREACH(pubRec, &client->inMessages, messages) + { + if (MqttPacketId(pubRec) == MqttPacketId(packet) && + MqttPacketType(pubRec) == MqttPacketTypePublish) + { + break; + } + } + + if (!pubRec) + { + MqttPacket *pubComp; + + TAILQ_FOREACH(pubComp, &client->inMessages, messages) + { + if (MqttPacketId(pubComp) == MqttPacketId(packet) && + MqttPacketType(pubComp) == MqttPacketTypePubComp) + { + break; + } + } + + if (pubComp) + { + MqttClientQueuePacket(client, pubComp); + } + else + { + LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet)); + client->stopped = 1; + } + } + else + { + MqttPacket *pubComp; + + TAILQ_REMOVE(&client->inMessages, pubRec, messages); + MqttPacketFree(pubRec); + + pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp, + MqttPacketId(packet)); + + TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages); + + MqttClientQueuePacket(client, pubComp); + } + + MqttPacketFree(packet); +} + +static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet) +{ + MqttPacket *pubRel; + + TAILQ_FOREACH(pubRel, &client->outMessages, messages) + { + if (MqttPacketId(pubRel) == MqttPacketId(packet) && + MqttPacketType(pubRel) == MqttPacketTypePubRel) + { + break; + } + } + + if (!pubRel) + { + LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet)); + client->stopped = 1; + } + else + { + TAILQ_REMOVE(&client->outMessages, pubRel, messages); + MqttPacketFree(pubRel); + + if (client->onPublish) + { + LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet)); + client->onPublish(client, MqttPacketId(packet)); + } + } + + MqttPacketFree(packet); +} + +static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet) +{ + MqttPacket *sub; + + assert(client != NULL); + assert(packet != NULL); + + TAILQ_FOREACH(sub, &client->outMessages, messages) + { + if (MqttPacketId(sub) == MqttPacketId(packet) && + MqttPacketType(sub) == MqttPacketTypeUnsubscribe) + { + break; + } + } + + if (!sub) + { + LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet)); + client->stopped = 1; + } + else + { + TAILQ_REMOVE(&client->outMessages, sub, messages); + MqttPacketFree(sub); + + if (client->onUnsubscribe) + { + LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet)); + client->onUnsubscribe(client, MqttPacketId(packet)); + } + } + + MqttPacketFree(packet); +} + +static int MqttClientRecvPacket(MqttClient *client) +{ + MqttPacket *packet = NULL; + + if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1) + return -1; + + LOG_DEBUG("received packet %s", MqttPacketName(packet->type)); + + switch (MqttPacketType(packet)) + { + case MqttPacketTypeConnAck: + MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet); + break; + + case MqttPacketTypePingResp: + MqttClientHandlePingResp(client); + break; + + case MqttPacketTypeSubAck: + MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet); + break; + + case MqttPacketTypePublish: + MqttClientHandlePublish(client, (MqttPacketPublish *) packet); + break; + + case MqttPacketTypePubAck: + MqttClientHandlePubAck(client, packet); + break; + + case MqttPacketTypePubRec: + MqttClientHandlePubRec(client, packet); + break; + + case MqttPacketTypePubRel: + MqttClientHandlePubRel(client, packet); + break; + + case MqttPacketTypePubComp: + MqttClientHandlePubComp(client, packet); + break; + + case MqttPacketTypeUnsubAck: + MqttClientHandleUnsubAck(client, packet); + break; + + default: + LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet)); + break; + } + + return 0; +} + +static uint16_t MqttClientNextPacketId(MqttClient *client) +{ + uint16_t id; + assert(client != NULL); + id = ++client->packetId; + if (id == 0) + id = ++client->packetId; + return id; +} + +static int64_t MqttPacketTimeSinceSent(MqttPacket *packet) +{ + int64_t now = GetCurrentTime(); + return now - packet->sentAt; +} + +static void MqttClientProcessInMessages(MqttClient *client) +{ + MqttPacket *packet, *next; + + LOG_DEBUG("processing inMessages"); + + TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next) + { + LOG_DEBUG("packet type:%s id:%d", + MqttPacketName(MqttPacketType(packet)), + MqttPacketId(packet)); + + if (MqttPacketType(packet) == MqttPacketTypePubComp) + { + int64_t elapsed = MqttPacketTimeSinceSent(packet); + if (packet->sentAt > 0 && + elapsed >= client->retryTimeout*1000) + { + LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64, + MqttPacketId(packet), elapsed); + + TAILQ_REMOVE(&client->inMessages, packet, messages); + + MqttPacketFree(packet); + } + } + } +} + +static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet) +{ + if (packet->sentAt > 0 && + MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000) + { + return 1; + } + + return 0; +} + +static void MqttClientProcessOutMessages(MqttClient *client) +{ + MqttPacket *packet, *next; + int inflight = MqttClientInflightMessageCount(client); + + LOG_DEBUG("processing outMessages inflight:%d", inflight); + + TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next) + { + LOG_DEBUG("packet type:%s id:%d state:%d", + MqttPacketName(MqttPacketType(packet)), + MqttPacketId(packet), + packet->state); + + switch (packet->state) + { + case MessageStateQueued: + if (inflight >= client->maxInflight) + { + LOG_DEBUG("cannot dequeue %s/%d", + MqttPacketName(MqttPacketType(packet)), + MqttPacketId(packet)); + break; + } + else + { + /* If there's less than maxInflight messages currently + inflight, we can dequeue some messages by falling + through to MessageStateSend. */ + LOG_DEBUG("dequeuing %s (%d)", + MqttPacketName(MqttPacketType(packet)), + MqttPacketId(packet)); + ++inflight; + } + + case MessageStateSend: + packet->state = MessageStateSent; + MqttClientQueuePacket(client, packet); + break; + + case MessageStateSent: + if (MqttPacketShouldResend(client, packet)) + { + packet->state = MessageStateSend; + } + break; + + default: + break; + } + } +} + +static void MqttClientProcessMessageQueue(MqttClient *client) +{ + MqttClientProcessInMessages(client); + MqttClientProcessOutMessages(client); +} diff --git a/src/deserialize.c b/src/deserialize.c new file mode 100644 index 0000000..19be4ce --- /dev/null +++ b/src/deserialize.c @@ -0,0 +1,278 @@ +#include "deserialize.h" +#include "packet.h" +#include "stream_mqtt.h" +#include "log.h" + +#include <stdlib.h> +#include <assert.h> + +typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream); + +static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream) +{ + size_t remainingLength = 0; + + if (StreamReadRemainingLength(&remainingLength, stream) == -1) + return -1; + + if (remainingLength != 2) + return -1; + + if (StreamReadUint16Be(&(*packet)->id, stream) == -1) + return -1; + + return 0; +} + +static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream) +{ + size_t remainingLength = 0; + + if (StreamReadRemainingLength(&remainingLength, stream) == -1) + return -1; + + if (remainingLength != 2) + return -1; + + if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1) + return -1; + + if (StreamRead(&(*packet)->returnCode, 1, stream) != 1) + return -1; + + return 0; +} + +static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream) +{ + size_t remainingLength = 0; + + if (StreamReadRemainingLength(&remainingLength, stream) == -1) + return -1; + + // 2 bytes for packet id and 1 byte for single return code + if (remainingLength != 3) + return -1; + + if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) + return -1; + + if (StreamRead(&((*packet)->returnCode), 1, stream) == -1) + return -1; + + return 0; +} + +static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream) +{ + size_t remainingLength = 0; + + if (StreamReadRemainingLength(&remainingLength, stream) == -1) + return -1; + + if (remainingLength != 2) + return -1; + + if (StreamReadUint16Be(&(*packet)->id, stream) == -1) + return -1; + + return 0; +} + +static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream) +{ + size_t remainingLength = 0; + size_t payloadSize = 0; + + if (StreamReadRemainingLength(&remainingLength, stream) == -1) + return -1; + + if (StreamReadMqttStringBuf(&(*packet)->topicName, stream) == -1) + return -1; + + LOG_DEBUG("remainingLength:%lu", remainingLength); + + payloadSize = remainingLength - (*packet)->topicName.len - 2; + + LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet), + payloadSize); + + if (MqttPacketHasId((const MqttPacket *) *packet)) + { + LOG_DEBUG("packet has id"); + payloadSize -= 2; + if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) + { + return -1; + } + } + + LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize); + + if (StringBufInit(&((*packet)->message), payloadSize) == -1) + return -1; + + if (StreamRead((*packet)->message.data, payloadSize, stream) == -1) + return -1; + + (*packet)->message.len = payloadSize; + + return 0; +} + +static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream) +{ + size_t remainingLength = 0; + char buffer[256]; + + (void) packet; + + if (StreamReadRemainingLength(&remainingLength, stream) == -1) + return -1; + + while (remainingLength > 0) + { + size_t l = sizeof(buffer); + + if (remainingLength < l) + l = remainingLength; + + if (StreamRead(buffer, l, stream) != (int64_t) l) + return -1; + + remainingLength -= l; + } + + return 0; +} + +static int ValidateFlags(int type, int flags) +{ + int rv = 0; + + switch (type) + { + case MqttPacketTypePublish: + { + int qos = (flags >> 1) & 2; + if (qos >= 0 && qos <= 2) + rv = 1; + break; + } + + case MqttPacketTypePubRel: + case MqttPacketTypeSubscribe: + case MqttPacketTypeUnsubscribe: + if (flags == 2) + { + rv = 1; + } + break; + + default: + if (flags == 0) + { + rv = 1; + } + break; + } + + return rv; +} + +int MqttPacketDeserialize(MqttPacket **packet, Stream *stream) +{ + MqttPacketDeserializeFunc deserializer = NULL; + char typeAndFlags; + int type; + int flags; + int rv; + + if (StreamRead(&typeAndFlags, 1, stream) != 1) + return -1; + + type = (typeAndFlags & 0xF0) >> 4; + flags = (typeAndFlags & 0x0F); + + if (!ValidateFlags(type, flags)) + { + return -1; + } + + switch (type) + { + case MqttPacketTypeConnect: + break; + + case MqttPacketTypeConnAck: + deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize; + break; + + case MqttPacketTypePublish: + deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize; + break; + + case MqttPacketTypePubAck: + deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; + break; + + case MqttPacketTypePubRec: + deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; + break; + + case MqttPacketTypePubRel: + deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; + break; + + case MqttPacketTypePubComp: + deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; + break; + + case MqttPacketTypeSubscribe: + break; + + case MqttPacketTypeSubAck: + deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize; + break; + + case MqttPacketTypeUnsubscribe: + break; + + case MqttPacketTypeUnsubAck: + deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize; + break; + + case MqttPacketTypePingReq: + break; + + case MqttPacketTypePingResp: + break; + + case MqttPacketTypeDisconnect: + break; + + default: + return -1; + } + + if (!deserializer) + { + deserializer = MqttPacketGenericDeserializer; + } + + *packet = MqttPacketNew(type); + + if (!*packet) + return -1; + + if (type == MqttPacketTypePublish) + { + MqttPacketPublishDup(*packet) = (flags >> 3) & 1; + MqttPacketPublishQos(*packet) = (flags >> 1) & 3; + MqttPacketPublishRetain(*packet) = flags & 1; + } + + rv = deserializer(packet, stream); + + return rv; +} diff --git a/src/deserialize.h b/src/deserialize.h new file mode 100644 index 0000000..826e42a --- /dev/null +++ b/src/deserialize.h @@ -0,0 +1,9 @@ +#ifndef DESERIALIZE_H +#define DESERIALIZE_H + +typedef struct MqttPacket MqttPacket; +typedef struct Stream Stream; + +int MqttPacketDeserialize(MqttPacket **packet, Stream *stream); + +#endif diff --git a/src/log.h b/src/log.h new file mode 100644 index 0000000..a9df317 --- /dev/null +++ b/src/log.h @@ -0,0 +1,57 @@ +#ifndef LOG_H +#define LOG_H + +#define LOG_LEVEL_DEBUG 0 +#define LOG_LEVEL_INFO 1 +#define LOG_LEVEL_WARNING 2 +#define LOG_LEVEL_ERROR 3 + +#if !defined(LOG_LEVEL) +#define LOG_LEVEL (LOG_LEVEL_ERROR+1) +#else +#include <stdio.h> +#include <string.h> +static inline const char *log_basename(const char *s) +{ +#if !defined(_WIN32) + const char *p = strrchr(s, '/'); +#else + const char *p = strrchr(s, '\\'); +#endif + + if (p) + return p+1; + + return s; +} +#endif + +#define LOG_DOLOG(level, fmt, ...) \ + fprintf(stderr, "%s %s %s:%d " fmt "\n", \ + #level, __FUNCTION__, log_basename(__FILE__), __LINE__, ##__VA_ARGS__) + +#if (LOG_LEVEL <= LOG_LEVEL_DEBUG) +#define LOG_DEBUG(FMT, ...) LOG_DOLOG(DEBUG, FMT, ##__VA_ARGS__) +#else +#define LOG_DEBUG(FMT, ...) +#endif + +#if (LOG_LEVEL <= LOG_LEVEL_INFO) +#define LOG_INFO(FMT, ...) LOG_DOLOG(INFO, FMT, ##__VA_ARGS__) +#else +#define LOG_INFO(FMT, ...) +#endif + +#if (LOG_LEVEL <= LOG_LEVEL_WARNING) +#define LOG_WARNING(FMT, ...) LOG_DOLOG(WARNING, FMT, ##__VA_ARGS__) +#else +#define LOG_WARNING(FMT, ...) +#endif + +#if (LOG_LEVEL <= LOG_LEVEL_ERROR) +#define LOG_ERROR(FMT, ...) LOG_DOLOG(ERROR, FMT, ##__VA_ARGS__) +#else +#define LOG_ERROR(FMT, ...) +#endif + +#endif diff --git a/src/misc.c b/src/misc.c new file mode 100644 index 0000000..ddd857a --- /dev/null +++ b/src/misc.c @@ -0,0 +1,48 @@ +#include "misc.h" + +#include <stdio.h> +#include <time.h> + +#if defined(_WIN32) +#error GetCurrentTime implementation missing +#else +int64_t GetCurrentTime() +{ + struct timespec t; + + if (clock_gettime(CLOCK_MONOTONIC, &t) == -1) + return -1; + + return ((int64_t) t.tv_sec * 1000) + (int64_t) t.tv_nsec / 1000 / 1000; +} +#endif + +// https://gist.github.com/ccbrown/9722406 +void DumpHex(const void* data, size_t size) { + char ascii[17]; + size_t i, j; + ascii[16] = '\0'; + for (i = 0; i < size; ++i) { + printf("%02X ", ((unsigned char*)data)[i]); + if (((unsigned char*)data)[i] >= ' ' && ((unsigned char*)data)[i] <= '~') { + ascii[i % 16] = ((unsigned char*)data)[i]; + } else { + ascii[i % 16] = '.'; + } + if ((i+1) % 8 == 0 || i+1 == size) { + printf(" "); + if ((i+1) % 16 == 0) { + printf("| %s \n", ascii); + } else if (i+1 == size) { + ascii[(i+1) % 16] = '\0'; + if ((i+1) % 16 <= 8) { + printf(" "); + } + for (j = (i+1) % 16; j < 16; ++j) { + printf(" "); + } + printf("| %s \n", ascii); + } + } + } +} diff --git a/src/misc.h b/src/misc.h new file mode 100644 index 0000000..133b7eb --- /dev/null +++ b/src/misc.h @@ -0,0 +1,18 @@ +#ifndef MISC_H +#define MISC_H + +#include <stdint.h> +#include <stdlib.h> + +/* + Returns the current time as milliseconds. The return value can only be + compared to another return value of the function. +*/ +int64_t GetCurrentTime(); + +/* + Simple hexdump to stdout. +*/ +void DumpHex(const void* data, size_t size); + +#endif diff --git a/src/mqtt.h b/src/mqtt.h new file mode 100644 index 0000000..84b42be --- /dev/null +++ b/src/mqtt.h @@ -0,0 +1,98 @@ +#ifndef MQTT_H +#define MQTT_H + +#if defined(__cplusplus) +extern "C" { +#endif + +#include <stdlib.h> + +typedef enum MqttConnectionStatus +{ + MqttConnectionAccepted = 0, + MqttConnectionInvalidProtocolVersion, + MqttConnectionIdentifierRejected, + MqttConnectionServerUnavailable, + MqttConnectionBadAuth, + MqttConnectionNotAuthorized +} MqttConnectionStatus; + +typedef enum MqttSubscriptionStatus +{ + MqttSubscriptionQos0 = 0, + MqttSubscriptionQos1 = 1, + MqttSubscriptionQos2 = 2, + MqttSubscriptionFailure = 0x80 +} MqttSubscriptionStatus; + +typedef struct MqttClient MqttClient; + +typedef void (*MqttClientOnConnectCallback)(MqttClient *client, + MqttConnectionStatus status, + int sessionPresent); + +typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client, + int id, + MqttSubscriptionStatus status); + +typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id); + +typedef void (*MqttClientOnMessageCallback)(MqttClient *client, + const char *topic, + const void *data, size_t size); + +typedef void (*MqttClientOnPublishCallback)(MqttClient *client, int id); + +MqttClient *MqttClientNew(const char *clientId, int cleanSession); + +void MqttClientFree(MqttClient *client); + +void MqttClientSetUserData(MqttClient *client, void *userData); + +void *MqttClientGetUserData(MqttClient *client); + +void MqttClientSetOnConnect(MqttClient *client, MqttClientOnConnectCallback cb); + +void MqttClientSetOnSubscribe(MqttClient *client, + MqttClientOnSubscribeCallback cb); + +void MqttClientSetOnUnsubscribe(MqttClient *client, + MqttClientOnUnsubscribeCallback cb); + +void MqttClientSetOnMessage(MqttClient *client, + MqttClientOnMessageCallback cb); + +void MqttClientSetOnPublish(MqttClient *client, + MqttClientOnPublishCallback cb); + +int MqttClientConnect(MqttClient *client, const char *host, short port, + int keepAlive); + +int MqttClientDisconnect(MqttClient *client); + +int MqttClientRunOnce(MqttClient *client); + +int MqttClientRun(MqttClient *client); + +int MqttClientSubscribe(MqttClient *client, const char *topicFilter, + int qos); + +int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter); + +int MqttClientPublish(MqttClient *client, int qos, int retain, + const char *topic, const void *data, size_t size); + +int MqttClientPublishCString(MqttClient *client, int qos, int retain, + const char *topic, const char *msg); + +void MqttClientSetPublishRetryTimeout(MqttClient *client, int timeout); + +void MqttClientSetMaxMessagesInflight(MqttClient *client, int max); + +void MqttClientSetMaxQueuedMessages(MqttClient *client, int max); + +#if defined(__cplusplus) +} +#endif + +#endif diff --git a/src/packet.c b/src/packet.c new file mode 100644 index 0000000..1e05330 --- /dev/null +++ b/src/packet.c @@ -0,0 +1,104 @@ +#include "packet.h" +#include "log.h" + +#include <string.h> +#include <stdio.h> +#include <assert.h> +#include <errno.h> + +const char *MqttPacketName(int type) +{ + switch (type) + { + case MqttPacketTypeConnect: return "CONNECT"; + case MqttPacketTypeConnAck: return "CONNACK"; + case MqttPacketTypePublish: return "PUBLISH"; + case MqttPacketTypePubAck: return "PUBACK"; + case MqttPacketTypePubRec: return "PUBREC"; + case MqttPacketTypePubRel: return "PUBREL"; + case MqttPacketTypePubComp: return "PUBCOMP"; + case MqttPacketTypeSubscribe: return "SUBSCRIBE"; + case MqttPacketTypeSubAck: return "SUBACK"; + case MqttPacketTypeUnsubscribe: return "UNSUBSCRIBE"; + case MqttPacketTypeUnsubAck: return "UNSUBACK"; + case MqttPacketTypePingReq: return "PINGREQ"; + case MqttPacketTypePingResp: return "PINGRESP"; + case MqttPacketTypeDisconnect: return "DISCONNECT"; + default: return NULL; + } +} + +static inline size_t MqttPacketStructSize(int type) +{ + switch (type) + { + case MqttPacketTypeConnect: return sizeof(MqttPacketConnect); + case MqttPacketTypeConnAck: return sizeof(MqttPacketConnAck); + case MqttPacketTypePublish: return sizeof(MqttPacketPublish); + case MqttPacketTypePubAck: + case MqttPacketTypePubRec: + case MqttPacketTypePubRel: + case MqttPacketTypePubComp: return sizeof(MqttPacket); + case MqttPacketTypeSubscribe: return sizeof(MqttPacketSubscribe); + case MqttPacketTypeSubAck: return sizeof(MqttPacketSubAck); + case MqttPacketTypeUnsubscribe: return sizeof(MqttPacketUnsubscribe); + case MqttPacketTypeUnsubAck: return sizeof(MqttPacket); + case MqttPacketTypePingReq: return sizeof(MqttPacket); + case MqttPacketTypePingResp: return sizeof(MqttPacket); + case MqttPacketTypeDisconnect: return sizeof(MqttPacket); + default: return (size_t) -1; + } +} + +MqttPacket *MqttPacketNew(int type) +{ + MqttPacket *packet = NULL; + + packet = (MqttPacket *) calloc(1, MqttPacketStructSize(type)); + if (!packet) + return NULL; + + packet->type = type; + + /* this will make sure that TAILQ_PREV does not segfault if a message + has not been added to a list at any point */ + packet->messages.tqe_prev = &packet->messages.tqe_next; + + return packet; +} + +MqttPacket *MqttPacketWithIdNew(int type, uint16_t id) +{ + MqttPacket *packet = MqttPacketNew(type); + if (!packet) + return NULL; + packet->id = id; + return packet; +} + +void MqttPacketFree(MqttPacket *packet) +{ + /* TODO: implement MqttPacketFree */ +} + +int MqttPacketHasId(const MqttPacket *packet) +{ + switch (packet->type) + { + case MqttPacketTypePublish: + return MqttPacketPublishQos(packet) > 0; + + case MqttPacketTypePubAck: + case MqttPacketTypePubRec: + case MqttPacketTypePubRel: + case MqttPacketTypePubComp: + case MqttPacketTypeSubscribe: + case MqttPacketTypeSubAck: + case MqttPacketTypeUnsubscribe: + case MqttPacketTypeUnsubAck: + return 1; + + default: + return 0; + } +} diff --git a/src/packet.h b/src/packet.h new file mode 100644 index 0000000..0f7b17e --- /dev/null +++ b/src/packet.h @@ -0,0 +1,123 @@ +#ifndef PACKET_H +#define PACKET_H + +#include <stdlib.h> +#include <stdint.h> +#include <assert.h> + +#include "stringbuf.h" + +#include "queue.h" + +enum +{ + MqttPacketTypeConnect = 0x1, + MqttPacketTypeConnAck = 0x2, + MqttPacketTypePublish = 0x3, + MqttPacketTypePubAck = 0x4, + MqttPacketTypePubRec = 0x5, + MqttPacketTypePubRel = 0x6, + MqttPacketTypePubComp = 0x7, + MqttPacketTypeSubscribe = 0x8, + MqttPacketTypeSubAck = 0x9, + MqttPacketTypeUnsubscribe = 0xA, + MqttPacketTypeUnsubAck = 0xB, + MqttPacketTypePingReq = 0xC, + MqttPacketTypePingResp = 0xD, + MqttPacketTypeDisconnect = 0xE +}; + +typedef struct MqttPacket MqttPacket; + +struct MqttPacket +{ + int type; + uint16_t id; + int state; + int flags; + int64_t sentAt; + SIMPLEQ_ENTRY(MqttPacket) sendQueue; + TAILQ_ENTRY(MqttPacket) messages; +}; + +#define MqttPacketType(packet) (((MqttPacket *) (packet))->type) + +#define MqttPacketId(packet) (((MqttPacket *) (packet))->id) + +#define MqttPacketSentAt(packet) (((MqttPacket *) (packet))->sentAt) + +typedef struct MqttPacketConnect MqttPacketConnect; + +struct MqttPacketConnect +{ + MqttPacket base; + char connectFlags; + uint16_t keepAlive; + StringBuf clientId; + StringBuf willTopic; + StringBuf willMessage; + StringBuf userName; + StringBuf password; +}; + +typedef struct MqttPacketConnAck MqttPacketConnAck; + +struct MqttPacketConnAck +{ + MqttPacket base; + unsigned char connAckFlags; + unsigned char returnCode; +}; + +typedef struct MqttPacketPublish MqttPacketPublish; + +struct MqttPacketPublish +{ + MqttPacket base; + StringBuf topicName; + StringBuf message; + char qos; + char dup; + char retain; +}; + +#define MqttPacketPublishQos(p) (((MqttPacketPublish *) p)->qos) +#define MqttPacketPublishDup(p) (((MqttPacketPublish *) p)->dup) +#define MqttPacketPublishRetain(p) (((MqttPacketPublish *) p)->retain) + +typedef struct MqttPacketSubscribe MqttPacketSubscribe; + +struct MqttPacketSubscribe +{ + MqttPacket base; + StringBuf topicFilter; + char qos; +}; + +typedef struct MqttPacketSubAck MqttPacketSubAck; + +struct MqttPacketSubAck +{ + MqttPacket base; + unsigned char returnCode; +}; + +typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe; + +struct MqttPacketUnsubscribe +{ + MqttPacket base; + StringBuf topicFilter; +}; + +const char *MqttPacketName(int type); + +MqttPacket *MqttPacketNew(int type); + +MqttPacket *MqttPacketWithIdNew(int type, uint16_t id); + +void MqttPacketFree(MqttPacket *packet); + +int MqttPacketHasId(const MqttPacket *packet); + +#endif diff --git a/src/queue.h b/src/queue.h new file mode 100644 index 0000000..dddb466 --- /dev/null +++ b/src/queue.h @@ -0,0 +1,846 @@ +/* $NetBSD: queue.h,v 1.68.2.1 2015/12/27 12:10:18 skrll Exp $ */ + +/* + * Copyright (c) 1991, 1993 + * The Regents of the University of California. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * @(#)queue.h 8.5 (Berkeley) 8/20/94 + */ + +#ifndef _SYS_QUEUE_H_ +#define _SYS_QUEUE_H_ + +/* + * This file defines five types of data structures: singly-linked lists, + * lists, simple queues, tail queues, and circular queues. + * + * A singly-linked list is headed by a single forward pointer. The + * elements are singly linked for minimum space and pointer manipulation + * overhead at the expense of O(n) removal for arbitrary elements. New + * elements can be added to the list after an existing element or at the + * head of the list. Elements being removed from the head of the list + * should use the explicit macro for this purpose for optimum + * efficiency. A singly-linked list may only be traversed in the forward + * direction. Singly-linked lists are ideal for applications with large + * datasets and few or no removals or for implementing a LIFO queue. + * + * A list is headed by a single forward pointer (or an array of forward + * pointers for a hash table header). The elements are doubly linked + * so that an arbitrary element can be removed without a need to + * traverse the list. New elements can be added to the list before + * or after an existing element or at the head of the list. A list + * may only be traversed in the forward direction. + * + * A simple queue is headed by a pair of pointers, one the head of the + * list and the other to the tail of the list. The elements are singly + * linked to save space, so elements can only be removed from the + * head of the list. New elements can be added to the list after + * an existing element, at the head of the list, or at the end of the + * list. A simple queue may only be traversed in the forward direction. + * + * A tail queue is headed by a pair of pointers, one to the head of the + * list and the other to the tail of the list. The elements are doubly + * linked so that an arbitrary element can be removed without a need to + * traverse the list. New elements can be added to the list before or + * after an existing element, at the head of the list, or at the end of + * the list. A tail queue may be traversed in either direction. + * + * A circle queue is headed by a pair of pointers, one to the head of the + * list and the other to the tail of the list. The elements are doubly + * linked so that an arbitrary element can be removed without a need to + * traverse the list. New elements can be added to the list before or after + * an existing element, at the head of the list, or at the end of the list. + * A circle queue may be traversed in either direction, but has a more + * complex end of list detection. + * + * For details on the use of these macros, see the queue(3) manual page. + */ + +/* + * Include the definition of NULL only on NetBSD because sys/null.h + * is not available elsewhere. This conditional makes the header + * portable and it can simply be dropped verbatim into any system. + * The caveat is that on other systems some other header + * must provide NULL before the macros can be used. + */ +#ifdef __NetBSD__ +#include <sys/null.h> +#endif + +#if defined(QUEUEDEBUG) +# if defined(_KERNEL) +# define QUEUEDEBUG_ABORT(...) panic(__VA_ARGS__) +# else +# include <err.h> +# define QUEUEDEBUG_ABORT(...) err(1, __VA_ARGS__) +# endif +#endif + +/* + * Singly-linked List definitions. + */ +#define SLIST_HEAD(name, type) \ +struct name { \ + struct type *slh_first; /* first element */ \ +} + +#define SLIST_HEAD_INITIALIZER(head) \ + { NULL } + +#define SLIST_ENTRY(type) \ +struct { \ + struct type *sle_next; /* next element */ \ +} + +/* + * Singly-linked List access methods. + */ +#define SLIST_FIRST(head) ((head)->slh_first) +#define SLIST_END(head) NULL +#define SLIST_EMPTY(head) ((head)->slh_first == NULL) +#define SLIST_NEXT(elm, field) ((elm)->field.sle_next) + +#define SLIST_FOREACH(var, head, field) \ + for((var) = (head)->slh_first; \ + (var) != SLIST_END(head); \ + (var) = (var)->field.sle_next) + +#define SLIST_FOREACH_SAFE(var, head, field, tvar) \ + for ((var) = SLIST_FIRST((head)); \ + (var) != SLIST_END(head) && \ + ((tvar) = SLIST_NEXT((var), field), 1); \ + (var) = (tvar)) + +/* + * Singly-linked List functions. + */ +#define SLIST_INIT(head) do { \ + (head)->slh_first = SLIST_END(head); \ +} while (/*CONSTCOND*/0) + +#define SLIST_INSERT_AFTER(slistelm, elm, field) do { \ + (elm)->field.sle_next = (slistelm)->field.sle_next; \ + (slistelm)->field.sle_next = (elm); \ +} while (/*CONSTCOND*/0) + +#define SLIST_INSERT_HEAD(head, elm, field) do { \ + (elm)->field.sle_next = (head)->slh_first; \ + (head)->slh_first = (elm); \ +} while (/*CONSTCOND*/0) + +#define SLIST_REMOVE_AFTER(slistelm, field) do { \ + (slistelm)->field.sle_next = \ + SLIST_NEXT(SLIST_NEXT((slistelm), field), field); \ +} while (/*CONSTCOND*/0) + +#define SLIST_REMOVE_HEAD(head, field) do { \ + (head)->slh_first = (head)->slh_first->field.sle_next; \ +} while (/*CONSTCOND*/0) + +#define SLIST_REMOVE(head, elm, type, field) do { \ + if ((head)->slh_first == (elm)) { \ + SLIST_REMOVE_HEAD((head), field); \ + } \ + else { \ + struct type *curelm = (head)->slh_first; \ + while(curelm->field.sle_next != (elm)) \ + curelm = curelm->field.sle_next; \ + curelm->field.sle_next = \ + curelm->field.sle_next->field.sle_next; \ + } \ +} while (/*CONSTCOND*/0) + + +/* + * List definitions. + */ +#define LIST_HEAD(name, type) \ +struct name { \ + struct type *lh_first; /* first element */ \ +} + +#define LIST_HEAD_INITIALIZER(head) \ + { NULL } + +#define LIST_ENTRY(type) \ +struct { \ + struct type *le_next; /* next element */ \ + struct type **le_prev; /* address of previous next element */ \ +} + +/* + * List access methods. + */ +#define LIST_FIRST(head) ((head)->lh_first) +#define LIST_END(head) NULL +#define LIST_EMPTY(head) ((head)->lh_first == LIST_END(head)) +#define LIST_NEXT(elm, field) ((elm)->field.le_next) + +#define LIST_FOREACH(var, head, field) \ + for ((var) = ((head)->lh_first); \ + (var) != LIST_END(head); \ + (var) = ((var)->field.le_next)) + +#define LIST_FOREACH_SAFE(var, head, field, tvar) \ + for ((var) = LIST_FIRST((head)); \ + (var) != LIST_END(head) && \ + ((tvar) = LIST_NEXT((var), field), 1); \ + (var) = (tvar)) + +#define LIST_MOVE(head1, head2) do { \ + LIST_INIT((head2)); \ + if (!LIST_EMPTY((head1))) { \ + (head2)->lh_first = (head1)->lh_first; \ + LIST_INIT((head1)); \ + } \ +} while (/*CONSTCOND*/0) + +/* + * List functions. + */ +#if defined(QUEUEDEBUG) +#define QUEUEDEBUG_LIST_INSERT_HEAD(head, elm, field) \ + if ((head)->lh_first && \ + (head)->lh_first->field.le_prev != &(head)->lh_first) \ + QUEUEDEBUG_ABORT("LIST_INSERT_HEAD %p %s:%d", (head), \ + __FILE__, __LINE__); +#define QUEUEDEBUG_LIST_OP(elm, field) \ + if ((elm)->field.le_next && \ + (elm)->field.le_next->field.le_prev != \ + &(elm)->field.le_next) \ + QUEUEDEBUG_ABORT("LIST_* forw %p %s:%d", (elm), \ + __FILE__, __LINE__); \ + if (*(elm)->field.le_prev != (elm)) \ + QUEUEDEBUG_ABORT("LIST_* back %p %s:%d", (elm), \ + __FILE__, __LINE__); +#define QUEUEDEBUG_LIST_POSTREMOVE(elm, field) \ + (elm)->field.le_next = (void *)1L; \ + (elm)->field.le_prev = (void *)1L; +#else +#define QUEUEDEBUG_LIST_INSERT_HEAD(head, elm, field) +#define QUEUEDEBUG_LIST_OP(elm, field) +#define QUEUEDEBUG_LIST_POSTREMOVE(elm, field) +#endif + +#define LIST_INIT(head) do { \ + (head)->lh_first = LIST_END(head); \ +} while (/*CONSTCOND*/0) + +#define LIST_INSERT_AFTER(listelm, elm, field) do { \ + QUEUEDEBUG_LIST_OP((listelm), field) \ + if (((elm)->field.le_next = (listelm)->field.le_next) != \ + LIST_END(head)) \ + (listelm)->field.le_next->field.le_prev = \ + &(elm)->field.le_next; \ + (listelm)->field.le_next = (elm); \ + (elm)->field.le_prev = &(listelm)->field.le_next; \ +} while (/*CONSTCOND*/0) + +#define LIST_INSERT_BEFORE(listelm, elm, field) do { \ + QUEUEDEBUG_LIST_OP((listelm), field) \ + (elm)->field.le_prev = (listelm)->field.le_prev; \ + (elm)->field.le_next = (listelm); \ + *(listelm)->field.le_prev = (elm); \ + (listelm)->field.le_prev = &(elm)->field.le_next; \ +} while (/*CONSTCOND*/0) + +#define LIST_INSERT_HEAD(head, elm, field) do { \ + QUEUEDEBUG_LIST_INSERT_HEAD((head), (elm), field) \ + if (((elm)->field.le_next = (head)->lh_first) != LIST_END(head))\ + (head)->lh_first->field.le_prev = &(elm)->field.le_next;\ + (head)->lh_first = (elm); \ + (elm)->field.le_prev = &(head)->lh_first; \ +} while (/*CONSTCOND*/0) + +#define LIST_REMOVE(elm, field) do { \ + QUEUEDEBUG_LIST_OP((elm), field) \ + if ((elm)->field.le_next != NULL) \ + (elm)->field.le_next->field.le_prev = \ + (elm)->field.le_prev; \ + *(elm)->field.le_prev = (elm)->field.le_next; \ + QUEUEDEBUG_LIST_POSTREMOVE((elm), field) \ +} while (/*CONSTCOND*/0) + +#define LIST_REPLACE(elm, elm2, field) do { \ + if (((elm2)->field.le_next = (elm)->field.le_next) != NULL) \ + (elm2)->field.le_next->field.le_prev = \ + &(elm2)->field.le_next; \ + (elm2)->field.le_prev = (elm)->field.le_prev; \ + *(elm2)->field.le_prev = (elm2); \ + QUEUEDEBUG_LIST_POSTREMOVE((elm), field) \ +} while (/*CONSTCOND*/0) + +/* + * Simple queue definitions. + */ +#define SIMPLEQ_HEAD(name, type) \ +struct name { \ + struct type *sqh_first; /* first element */ \ + struct type **sqh_last; /* addr of last next element */ \ +} + +#define SIMPLEQ_HEAD_INITIALIZER(head) \ + { NULL, &(head).sqh_first } + +#define SIMPLEQ_ENTRY(type) \ +struct { \ + struct type *sqe_next; /* next element */ \ +} + +/* + * Simple queue access methods. + */ +#define SIMPLEQ_FIRST(head) ((head)->sqh_first) +#define SIMPLEQ_END(head) NULL +#define SIMPLEQ_EMPTY(head) ((head)->sqh_first == SIMPLEQ_END(head)) +#define SIMPLEQ_NEXT(elm, field) ((elm)->field.sqe_next) + +#define SIMPLEQ_FOREACH(var, head, field) \ + for ((var) = ((head)->sqh_first); \ + (var) != SIMPLEQ_END(head); \ + (var) = ((var)->field.sqe_next)) + +#define SIMPLEQ_FOREACH_SAFE(var, head, field, next) \ + for ((var) = ((head)->sqh_first); \ + (var) != SIMPLEQ_END(head) && \ + ((next = ((var)->field.sqe_next)), 1); \ + (var) = (next)) + +/* + * Simple queue functions. + */ +#define SIMPLEQ_INIT(head) do { \ + (head)->sqh_first = NULL; \ + (head)->sqh_last = &(head)->sqh_first; \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_INSERT_HEAD(head, elm, field) do { \ + if (((elm)->field.sqe_next = (head)->sqh_first) == NULL) \ + (head)->sqh_last = &(elm)->field.sqe_next; \ + (head)->sqh_first = (elm); \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_INSERT_TAIL(head, elm, field) do { \ + (elm)->field.sqe_next = NULL; \ + *(head)->sqh_last = (elm); \ + (head)->sqh_last = &(elm)->field.sqe_next; \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_INSERT_AFTER(head, listelm, elm, field) do { \ + if (((elm)->field.sqe_next = (listelm)->field.sqe_next) == NULL)\ + (head)->sqh_last = &(elm)->field.sqe_next; \ + (listelm)->field.sqe_next = (elm); \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_REMOVE_HEAD(head, field) do { \ + if (((head)->sqh_first = (head)->sqh_first->field.sqe_next) == NULL) \ + (head)->sqh_last = &(head)->sqh_first; \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_REMOVE_AFTER(head, elm, field) do { \ + if (((elm)->field.sqe_next = (elm)->field.sqe_next->field.sqe_next) \ + == NULL) \ + (head)->sqh_last = &(elm)->field.sqe_next; \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_REMOVE(head, elm, type, field) do { \ + if ((head)->sqh_first == (elm)) { \ + SIMPLEQ_REMOVE_HEAD((head), field); \ + } else { \ + struct type *curelm = (head)->sqh_first; \ + while (curelm->field.sqe_next != (elm)) \ + curelm = curelm->field.sqe_next; \ + if ((curelm->field.sqe_next = \ + curelm->field.sqe_next->field.sqe_next) == NULL) \ + (head)->sqh_last = &(curelm)->field.sqe_next; \ + } \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_CONCAT(head1, head2) do { \ + if (!SIMPLEQ_EMPTY((head2))) { \ + *(head1)->sqh_last = (head2)->sqh_first; \ + (head1)->sqh_last = (head2)->sqh_last; \ + SIMPLEQ_INIT((head2)); \ + } \ +} while (/*CONSTCOND*/0) + +#define SIMPLEQ_LAST(head, type, field) \ + (SIMPLEQ_EMPTY((head)) ? \ + NULL : \ + ((struct type *)(void *) \ + ((char *)((head)->sqh_last) - offsetof(struct type, field)))) + +/* + * Tail queue definitions. + */ +#define _TAILQ_HEAD(name, type, qual) \ +struct name { \ + qual type *tqh_first; /* first element */ \ + qual type *qual *tqh_last; /* addr of last next element */ \ +} +#define TAILQ_HEAD(name, type) _TAILQ_HEAD(name, struct type,) + +#define TAILQ_HEAD_INITIALIZER(head) \ + { TAILQ_END(head), &(head).tqh_first } + +#define _TAILQ_ENTRY(type, qual) \ +struct { \ + qual type *tqe_next; /* next element */ \ + qual type *qual *tqe_prev; /* address of previous next element */\ +} +#define TAILQ_ENTRY(type) _TAILQ_ENTRY(struct type,) + +/* + * Tail queue access methods. + */ +#define TAILQ_FIRST(head) ((head)->tqh_first) +#define TAILQ_END(head) (NULL) +#define TAILQ_NEXT(elm, field) ((elm)->field.tqe_next) +#define TAILQ_LAST(head, headname) \ + (*(((struct headname *)(void *)((head)->tqh_last))->tqh_last)) +#define TAILQ_PREV(elm, headname, field) \ + (*(((struct headname *)(void *)((elm)->field.tqe_prev))->tqh_last)) +#define TAILQ_EMPTY(head) (TAILQ_FIRST(head) == TAILQ_END(head)) + + +#define TAILQ_FOREACH(var, head, field) \ + for ((var) = ((head)->tqh_first); \ + (var) != TAILQ_END(head); \ + (var) = ((var)->field.tqe_next)) + +#define TAILQ_FOREACH_SAFE(var, head, field, next) \ + for ((var) = ((head)->tqh_first); \ + (var) != TAILQ_END(head) && \ + ((next) = TAILQ_NEXT(var, field), 1); (var) = (next)) + +#define TAILQ_FOREACH_REVERSE(var, head, headname, field) \ + for ((var) = TAILQ_LAST((head), headname); \ + (var) != TAILQ_END(head); \ + (var) = TAILQ_PREV((var), headname, field)) + +#define TAILQ_FOREACH_REVERSE_SAFE(var, head, headname, field, prev) \ + for ((var) = TAILQ_LAST((head), headname); \ + (var) != TAILQ_END(head) && \ + ((prev) = TAILQ_PREV((var), headname, field), 1); (var) = (prev)) + +/* + * Tail queue functions. + */ +#if defined(QUEUEDEBUG) +#define QUEUEDEBUG_TAILQ_INSERT_HEAD(head, elm, field) \ + if ((head)->tqh_first && \ + (head)->tqh_first->field.tqe_prev != &(head)->tqh_first) \ + QUEUEDEBUG_ABORT("TAILQ_INSERT_HEAD %p %s:%d", (head), \ + __FILE__, __LINE__); +#define QUEUEDEBUG_TAILQ_INSERT_TAIL(head, elm, field) \ + if (*(head)->tqh_last != NULL) \ + QUEUEDEBUG_ABORT("TAILQ_INSERT_TAIL %p %s:%d", (head), \ + __FILE__, __LINE__); +#define QUEUEDEBUG_TAILQ_OP(elm, field) \ + if ((elm)->field.tqe_next && \ + (elm)->field.tqe_next->field.tqe_prev != \ + &(elm)->field.tqe_next) \ + QUEUEDEBUG_ABORT("TAILQ_* forw %p %s:%d", (elm), \ + __FILE__, __LINE__); \ + if (*(elm)->field.tqe_prev != (elm)) \ + QUEUEDEBUG_ABORT("TAILQ_* back %p %s:%d", (elm), \ + __FILE__, __LINE__); +#define QUEUEDEBUG_TAILQ_PREREMOVE(head, elm, field) \ + if ((elm)->field.tqe_next == NULL && \ + (head)->tqh_last != &(elm)->field.tqe_next) \ + QUEUEDEBUG_ABORT("TAILQ_PREREMOVE head %p elm %p %s:%d",\ + (head), (elm), __FILE__, __LINE__); +#define QUEUEDEBUG_TAILQ_POSTREMOVE(elm, field) \ + (elm)->field.tqe_next = (void *)1L; \ + (elm)->field.tqe_prev = (void *)1L; +#else +#define QUEUEDEBUG_TAILQ_INSERT_HEAD(head, elm, field) +#define QUEUEDEBUG_TAILQ_INSERT_TAIL(head, elm, field) +#define QUEUEDEBUG_TAILQ_OP(elm, field) +#define QUEUEDEBUG_TAILQ_PREREMOVE(head, elm, field) +#define QUEUEDEBUG_TAILQ_POSTREMOVE(elm, field) +#endif + +#define TAILQ_INIT(head) do { \ + (head)->tqh_first = TAILQ_END(head); \ + (head)->tqh_last = &(head)->tqh_first; \ +} while (/*CONSTCOND*/0) + +#define TAILQ_INSERT_HEAD(head, elm, field) do { \ + QUEUEDEBUG_TAILQ_INSERT_HEAD((head), (elm), field) \ + if (((elm)->field.tqe_next = (head)->tqh_first) != TAILQ_END(head))\ + (head)->tqh_first->field.tqe_prev = \ + &(elm)->field.tqe_next; \ + else \ + (head)->tqh_last = &(elm)->field.tqe_next; \ + (head)->tqh_first = (elm); \ + (elm)->field.tqe_prev = &(head)->tqh_first; \ +} while (/*CONSTCOND*/0) + +#define TAILQ_INSERT_TAIL(head, elm, field) do { \ + QUEUEDEBUG_TAILQ_INSERT_TAIL((head), (elm), field) \ + (elm)->field.tqe_next = TAILQ_END(head); \ + (elm)->field.tqe_prev = (head)->tqh_last; \ + *(head)->tqh_last = (elm); \ + (head)->tqh_last = &(elm)->field.tqe_next; \ +} while (/*CONSTCOND*/0) + +#define TAILQ_INSERT_AFTER(head, listelm, elm, field) do { \ + QUEUEDEBUG_TAILQ_OP((listelm), field) \ + if (((elm)->field.tqe_next = (listelm)->field.tqe_next) != \ + TAILQ_END(head)) \ + (elm)->field.tqe_next->field.tqe_prev = \ + &(elm)->field.tqe_next; \ + else \ + (head)->tqh_last = &(elm)->field.tqe_next; \ + (listelm)->field.tqe_next = (elm); \ + (elm)->field.tqe_prev = &(listelm)->field.tqe_next; \ +} while (/*CONSTCOND*/0) + +#define TAILQ_INSERT_BEFORE(listelm, elm, field) do { \ + QUEUEDEBUG_TAILQ_OP((listelm), field) \ + (elm)->field.tqe_prev = (listelm)->field.tqe_prev; \ + (elm)->field.tqe_next = (listelm); \ + *(listelm)->field.tqe_prev = (elm); \ + (listelm)->field.tqe_prev = &(elm)->field.tqe_next; \ +} while (/*CONSTCOND*/0) + +#define TAILQ_REMOVE(head, elm, field) do { \ + QUEUEDEBUG_TAILQ_PREREMOVE((head), (elm), field) \ + QUEUEDEBUG_TAILQ_OP((elm), field) \ + if (((elm)->field.tqe_next) != TAILQ_END(head)) \ + (elm)->field.tqe_next->field.tqe_prev = \ + (elm)->field.tqe_prev; \ + else \ + (head)->tqh_last = (elm)->field.tqe_prev; \ + *(elm)->field.tqe_prev = (elm)->field.tqe_next; \ + QUEUEDEBUG_TAILQ_POSTREMOVE((elm), field); \ +} while (/*CONSTCOND*/0) + +#define TAILQ_REPLACE(head, elm, elm2, field) do { \ + if (((elm2)->field.tqe_next = (elm)->field.tqe_next) != \ + TAILQ_END(head)) \ + (elm2)->field.tqe_next->field.tqe_prev = \ + &(elm2)->field.tqe_next; \ + else \ + (head)->tqh_last = &(elm2)->field.tqe_next; \ + (elm2)->field.tqe_prev = (elm)->field.tqe_prev; \ + *(elm2)->field.tqe_prev = (elm2); \ + QUEUEDEBUG_TAILQ_POSTREMOVE((elm), field); \ +} while (/*CONSTCOND*/0) + +#define TAILQ_CONCAT(head1, head2, field) do { \ + if (!TAILQ_EMPTY(head2)) { \ + *(head1)->tqh_last = (head2)->tqh_first; \ + (head2)->tqh_first->field.tqe_prev = (head1)->tqh_last; \ + (head1)->tqh_last = (head2)->tqh_last; \ + TAILQ_INIT((head2)); \ + } \ +} while (/*CONSTCOND*/0) + +/* + * Singly-linked Tail queue declarations. + */ +#define STAILQ_HEAD(name, type) \ +struct name { \ + struct type *stqh_first; /* first element */ \ + struct type **stqh_last; /* addr of last next element */ \ +} + +#define STAILQ_HEAD_INITIALIZER(head) \ + { NULL, &(head).stqh_first } + +#define STAILQ_ENTRY(type) \ +struct { \ + struct type *stqe_next; /* next element */ \ +} + +/* + * Singly-linked Tail queue access methods. + */ +#define STAILQ_FIRST(head) ((head)->stqh_first) +#define STAILQ_END(head) NULL +#define STAILQ_NEXT(elm, field) ((elm)->field.stqe_next) +#define STAILQ_EMPTY(head) (STAILQ_FIRST(head) == STAILQ_END(head)) + +/* + * Singly-linked Tail queue functions. + */ +#define STAILQ_INIT(head) do { \ + (head)->stqh_first = NULL; \ + (head)->stqh_last = &(head)->stqh_first; \ +} while (/*CONSTCOND*/0) + +#define STAILQ_INSERT_HEAD(head, elm, field) do { \ + if (((elm)->field.stqe_next = (head)->stqh_first) == NULL) \ + (head)->stqh_last = &(elm)->field.stqe_next; \ + (head)->stqh_first = (elm); \ +} while (/*CONSTCOND*/0) + +#define STAILQ_INSERT_TAIL(head, elm, field) do { \ + (elm)->field.stqe_next = NULL; \ + *(head)->stqh_last = (elm); \ + (head)->stqh_last = &(elm)->field.stqe_next; \ +} while (/*CONSTCOND*/0) + +#define STAILQ_INSERT_AFTER(head, listelm, elm, field) do { \ + if (((elm)->field.stqe_next = (listelm)->field.stqe_next) == NULL)\ + (head)->stqh_last = &(elm)->field.stqe_next; \ + (listelm)->field.stqe_next = (elm); \ +} while (/*CONSTCOND*/0) + +#define STAILQ_REMOVE_HEAD(head, field) do { \ + if (((head)->stqh_first = (head)->stqh_first->field.stqe_next) == NULL) \ + (head)->stqh_last = &(head)->stqh_first; \ +} while (/*CONSTCOND*/0) + +#define STAILQ_REMOVE(head, elm, type, field) do { \ + if ((head)->stqh_first == (elm)) { \ + STAILQ_REMOVE_HEAD((head), field); \ + } else { \ + struct type *curelm = (head)->stqh_first; \ + while (curelm->field.stqe_next != (elm)) \ + curelm = curelm->field.stqe_next; \ + if ((curelm->field.stqe_next = \ + curelm->field.stqe_next->field.stqe_next) == NULL) \ + (head)->stqh_last = &(curelm)->field.stqe_next; \ + } \ +} while (/*CONSTCOND*/0) + +#define STAILQ_FOREACH(var, head, field) \ + for ((var) = ((head)->stqh_first); \ + (var); \ + (var) = ((var)->field.stqe_next)) + +#define STAILQ_FOREACH_SAFE(var, head, field, tvar) \ + for ((var) = STAILQ_FIRST((head)); \ + (var) && ((tvar) = STAILQ_NEXT((var), field), 1); \ + (var) = (tvar)) + +#define STAILQ_CONCAT(head1, head2) do { \ + if (!STAILQ_EMPTY((head2))) { \ + *(head1)->stqh_last = (head2)->stqh_first; \ + (head1)->stqh_last = (head2)->stqh_last; \ + STAILQ_INIT((head2)); \ + } \ +} while (/*CONSTCOND*/0) + +#define STAILQ_LAST(head, type, field) \ + (STAILQ_EMPTY((head)) ? \ + NULL : \ + ((struct type *)(void *) \ + ((char *)((head)->stqh_last) - offsetof(struct type, field)))) + + +#ifndef _KERNEL +/* + * Circular queue definitions. Do not use. We still keep the macros + * for compatibility but because of pointer aliasing issues their use + * is discouraged! + */ + +/* + * __launder_type(): We use this ugly hack to work around the the compiler + * noticing that two types may not alias each other and elide tests in code. + * We hit this in the CIRCLEQ macros when comparing 'struct name *' and + * 'struct type *' (see CIRCLEQ_HEAD()). Modern compilers (such as GCC + * 4.8) declare these comparisons as always false, causing the code to + * not run as designed. + * + * This hack is only to be used for comparisons and thus can be fully const. + * Do not use for assignment. + * + * If we ever choose to change the ABI of the CIRCLEQ macros, we could fix + * this by changing the head/tail sentinal values, but see the note above + * this one. + */ +static __inline const void * __launder_type(const void *); +static __inline const void * +__launder_type(const void *__x) +{ + __asm __volatile("" : "+r" (__x)); + return __x; +} + +#if defined(QUEUEDEBUG) +#define QUEUEDEBUG_CIRCLEQ_HEAD(head, field) \ + if ((head)->cqh_first != CIRCLEQ_ENDC(head) && \ + (head)->cqh_first->field.cqe_prev != CIRCLEQ_ENDC(head)) \ + QUEUEDEBUG_ABORT("CIRCLEQ head forw %p %s:%d", (head), \ + __FILE__, __LINE__); \ + if ((head)->cqh_last != CIRCLEQ_ENDC(head) && \ + (head)->cqh_last->field.cqe_next != CIRCLEQ_ENDC(head)) \ + QUEUEDEBUG_ABORT("CIRCLEQ head back %p %s:%d", (head), \ + __FILE__, __LINE__); +#define QUEUEDEBUG_CIRCLEQ_ELM(head, elm, field) \ + if ((elm)->field.cqe_next == CIRCLEQ_ENDC(head)) { \ + if ((head)->cqh_last != (elm)) \ + QUEUEDEBUG_ABORT("CIRCLEQ elm last %p %s:%d", \ + (elm), __FILE__, __LINE__); \ + } else { \ + if ((elm)->field.cqe_next->field.cqe_prev != (elm)) \ + QUEUEDEBUG_ABORT("CIRCLEQ elm forw %p %s:%d", \ + (elm), __FILE__, __LINE__); \ + } \ + if ((elm)->field.cqe_prev == CIRCLEQ_ENDC(head)) { \ + if ((head)->cqh_first != (elm)) \ + QUEUEDEBUG_ABORT("CIRCLEQ elm first %p %s:%d", \ + (elm), __FILE__, __LINE__); \ + } else { \ + if ((elm)->field.cqe_prev->field.cqe_next != (elm)) \ + QUEUEDEBUG_ABORT("CIRCLEQ elm prev %p %s:%d", \ + (elm), __FILE__, __LINE__); \ + } +#define QUEUEDEBUG_CIRCLEQ_POSTREMOVE(elm, field) \ + (elm)->field.cqe_next = (void *)1L; \ + (elm)->field.cqe_prev = (void *)1L; +#else +#define QUEUEDEBUG_CIRCLEQ_HEAD(head, field) +#define QUEUEDEBUG_CIRCLEQ_ELM(head, elm, field) +#define QUEUEDEBUG_CIRCLEQ_POSTREMOVE(elm, field) +#endif + +#define CIRCLEQ_HEAD(name, type) \ +struct name { \ + struct type *cqh_first; /* first element */ \ + struct type *cqh_last; /* last element */ \ +} + +#define CIRCLEQ_HEAD_INITIALIZER(head) \ + { CIRCLEQ_END(&head), CIRCLEQ_END(&head) } + +#define CIRCLEQ_ENTRY(type) \ +struct { \ + struct type *cqe_next; /* next element */ \ + struct type *cqe_prev; /* previous element */ \ +} + +/* + * Circular queue functions. + */ +#define CIRCLEQ_INIT(head) do { \ + (head)->cqh_first = CIRCLEQ_END(head); \ + (head)->cqh_last = CIRCLEQ_END(head); \ +} while (/*CONSTCOND*/0) + +#define CIRCLEQ_INSERT_AFTER(head, listelm, elm, field) do { \ + QUEUEDEBUG_CIRCLEQ_HEAD((head), field) \ + QUEUEDEBUG_CIRCLEQ_ELM((head), (listelm), field) \ + (elm)->field.cqe_next = (listelm)->field.cqe_next; \ + (elm)->field.cqe_prev = (listelm); \ + if ((listelm)->field.cqe_next == CIRCLEQ_ENDC(head)) \ + (head)->cqh_last = (elm); \ + else \ + (listelm)->field.cqe_next->field.cqe_prev = (elm); \ + (listelm)->field.cqe_next = (elm); \ +} while (/*CONSTCOND*/0) + +#define CIRCLEQ_INSERT_BEFORE(head, listelm, elm, field) do { \ + QUEUEDEBUG_CIRCLEQ_HEAD((head), field) \ + QUEUEDEBUG_CIRCLEQ_ELM((head), (listelm), field) \ + (elm)->field.cqe_next = (listelm); \ + (elm)->field.cqe_prev = (listelm)->field.cqe_prev; \ + if ((listelm)->field.cqe_prev == CIRCLEQ_ENDC(head)) \ + (head)->cqh_first = (elm); \ + else \ + (listelm)->field.cqe_prev->field.cqe_next = (elm); \ + (listelm)->field.cqe_prev = (elm); \ +} while (/*CONSTCOND*/0) + +#define CIRCLEQ_INSERT_HEAD(head, elm, field) do { \ + QUEUEDEBUG_CIRCLEQ_HEAD((head), field) \ + (elm)->field.cqe_next = (head)->cqh_first; \ + (elm)->field.cqe_prev = CIRCLEQ_END(head); \ + if ((head)->cqh_last == CIRCLEQ_ENDC(head)) \ + (head)->cqh_last = (elm); \ + else \ + (head)->cqh_first->field.cqe_prev = (elm); \ + (head)->cqh_first = (elm); \ +} while (/*CONSTCOND*/0) + +#define CIRCLEQ_INSERT_TAIL(head, elm, field) do { \ + QUEUEDEBUG_CIRCLEQ_HEAD((head), field) \ + (elm)->field.cqe_next = CIRCLEQ_END(head); \ + (elm)->field.cqe_prev = (head)->cqh_last; \ + if ((head)->cqh_first == CIRCLEQ_ENDC(head)) \ + (head)->cqh_first = (elm); \ + else \ + (head)->cqh_last->field.cqe_next = (elm); \ + (head)->cqh_last = (elm); \ +} while (/*CONSTCOND*/0) + +#define CIRCLEQ_REMOVE(head, elm, field) do { \ + QUEUEDEBUG_CIRCLEQ_HEAD((head), field) \ + QUEUEDEBUG_CIRCLEQ_ELM((head), (elm), field) \ + if ((elm)->field.cqe_next == CIRCLEQ_ENDC(head)) \ + (head)->cqh_last = (elm)->field.cqe_prev; \ + else \ + (elm)->field.cqe_next->field.cqe_prev = \ + (elm)->field.cqe_prev; \ + if ((elm)->field.cqe_prev == CIRCLEQ_ENDC(head)) \ + (head)->cqh_first = (elm)->field.cqe_next; \ + else \ + (elm)->field.cqe_prev->field.cqe_next = \ + (elm)->field.cqe_next; \ + QUEUEDEBUG_CIRCLEQ_POSTREMOVE((elm), field) \ +} while (/*CONSTCOND*/0) + +#define CIRCLEQ_FOREACH(var, head, field) \ + for ((var) = ((head)->cqh_first); \ + (var) != CIRCLEQ_ENDC(head); \ + (var) = ((var)->field.cqe_next)) + +#define CIRCLEQ_FOREACH_REVERSE(var, head, field) \ + for ((var) = ((head)->cqh_last); \ + (var) != CIRCLEQ_ENDC(head); \ + (var) = ((var)->field.cqe_prev)) + +/* + * Circular queue access methods. + */ +#define CIRCLEQ_FIRST(head) ((head)->cqh_first) +#define CIRCLEQ_LAST(head) ((head)->cqh_last) +/* For comparisons */ +#define CIRCLEQ_ENDC(head) (__launder_type(head)) +/* For assignments */ +#define CIRCLEQ_END(head) ((void *)(head)) +#define CIRCLEQ_NEXT(elm, field) ((elm)->field.cqe_next) +#define CIRCLEQ_PREV(elm, field) ((elm)->field.cqe_prev) +#define CIRCLEQ_EMPTY(head) \ + (CIRCLEQ_FIRST(head) == CIRCLEQ_ENDC(head)) + +#define CIRCLEQ_LOOP_NEXT(head, elm, field) \ + (((elm)->field.cqe_next == CIRCLEQ_ENDC(head)) \ + ? ((head)->cqh_first) \ + : (elm->field.cqe_next)) +#define CIRCLEQ_LOOP_PREV(head, elm, field) \ + (((elm)->field.cqe_prev == CIRCLEQ_ENDC(head)) \ + ? ((head)->cqh_last) \ + : (elm->field.cqe_prev)) +#endif /* !_KERNEL */ + +#endif /* !_SYS_QUEUE_H_ */ diff --git a/src/serialize.c b/src/serialize.c new file mode 100644 index 0000000..d14cf03 --- /dev/null +++ b/src/serialize.c @@ -0,0 +1,309 @@ +#include "serialize.h" +#include "stringbuf.h" +#include "packet.h" +#include "stream_mqtt.h" +#include "log.h" + +#include <stdlib.h> +#include <assert.h> + +typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet, + Stream *stream); + +static const StringBuf MqttProtocolId = StaticStringBuf("MQTT"); +static const char MqttProtocolLevel = 0x04; + +static inline size_t MqttStringLengthSerialized(const StringBuf *s) +{ + return 2 + s->len; +} + +static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet) +{ + size_t remainingLength = 0; + + remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2; + + remainingLength += MqttStringLengthSerialized(&packet->clientId); + + if (packet->connectFlags & 0x80) + remainingLength += MqttStringLengthSerialized(&packet->userName); + + if (packet->connectFlags & 0x40) + remainingLength += MqttStringLengthSerialized(&packet->password); + + if (packet->connectFlags & 0x04) + remainingLength += MqttStringLengthSerialized(&packet->willTopic) + + MqttStringLengthSerialized(&packet->willMessage); + + return remainingLength; +} + +static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet) +{ + return 2 + MqttStringLengthSerialized(&packet->topicFilter) + 1; +} + +static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet) +{ + return 2 + MqttStringLengthSerialized(&packet->topicFilter); +} + +static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet) +{ + size_t remainingLength = 0; + + remainingLength += MqttStringLengthSerialized(&packet->topicName); + + // Packet id + if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2) + { + remainingLength += 2; + } + + remainingLength += packet->message.len; + + return remainingLength; +} + +static size_t MqttPacketGetRemainingLength(const MqttPacket *packet) +{ + switch (packet->type) + { + case MqttPacketTypeConnect: + return MqttPacketConnectGetRemainingLength( + (MqttPacketConnect *) packet); + + case MqttPacketTypeSubscribe: + return MqttPacketSubscribeGetRemainingLength( + (MqttPacketSubscribe *) packet); + + case MqttPacketTypePublish: + return MqttPacketPublishGetRemainingLength( + (MqttPacketPublish *) packet); + + case MqttPacketTypePubAck: + case MqttPacketTypePubRec: + case MqttPacketTypePubRel: + case MqttPacketTypePubComp: + return 2; + + case MqttPacketTypeUnsubscribe: + return MqttPacketUnsubscribeGetRemainingLength( + (MqttPacketUnsubscribe *) packet); + + default: + return 0; + } +} + +static int MqttPacketFlags(const MqttPacket *packet) +{ + switch (packet->type) + { + case MqttPacketTypePublish: + return ((MqttPacketPublishDup(packet) & 1) << 3) | + ((MqttPacketPublishQos(packet) & 3) << 1) | + (MqttPacketPublishRetain(packet) & 1); + + case MqttPacketTypePubRel: + case MqttPacketTypeSubscribe: + case MqttPacketTypeUnsubscribe: + return 0x2; + + default: + return 0; + } +} + +static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream) +{ + unsigned char typeAndFlags; + size_t remainingLength; + + typeAndFlags = ((packet->type & 0x0F) << 4) | + (MqttPacketFlags(packet) & 0x0F); + remainingLength = MqttPacketGetRemainingLength(packet); + + LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type, + MqttPacketName(packet->type), MqttPacketFlags(packet)); + + if (StreamWrite(&typeAndFlags, 1, stream) != 1) + return -1; + + if (StreamWriteRemainingLength(remainingLength, stream) == -1) + return -1; + + return 0; +} + +static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream) +{ + assert(MqttPacketHasId((const MqttPacket *) packet)); + + if (MqttPacketBaseSerialize(packet, stream) == -1) + return -1; + + if (StreamWriteUint16Be(packet->id, stream) == -1) + return -1; + + return 0; +} + +static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream) +{ + if (MqttPacketBaseSerialize(&packet->base, stream) == -1) + return -1; + + if (StreamWriteMqttStringBuf(&MqttProtocolId, stream) == -1) + return -1; + + if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1) + return -1; + + if (StreamWrite(&packet->connectFlags, 1, stream) != 1) + return -1; + + if (StreamWriteUint16Be(packet->keepAlive, stream) == -1) + return -1; + + if (StreamWriteMqttStringBuf(&packet->clientId, stream) == -1) + return -1; + + if (packet->connectFlags & 0x04) + { + if (StreamWriteMqttStringBuf(&packet->willTopic, stream) == -1) + return -1; + + if (StreamWriteMqttStringBuf(&packet->willMessage, stream) == -1) + return -1; + } + + if (packet->connectFlags & 0x80) + { + if (StreamWriteMqttStringBuf(&packet->userName, stream) == -1) + return -1; + } + + if (packet->connectFlags & 0x40) + { + if (StreamWriteMqttStringBuf(&packet->password, stream) == -1) + return -1; + } + + return 0; +} + +static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream) +{ + if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) + return -1; + + if (StreamWriteMqttStringBuf(&packet->topicFilter, stream) == -1) + return -1; + + if (StreamWrite(&packet->qos, 1, stream) == -1) + return -1; + + return 0; +} + +static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream) +{ + if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) + return -1; + + if (StreamWriteMqttStringBuf(&packet->topicFilter, stream) == -1) + return -1; + + return 0; +} + +static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream) +{ + if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1) + return -1; + + if (StreamWriteMqttStringBuf(&packet->topicName, stream) == -1) + return -1; + + LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet)); + + if (MqttPacketPublishQos(packet) > 0) + { + if (StreamWriteUint16Be(packet->base.id, stream) == -1) + return -1; + } + + if (StreamWrite(packet->message.data, packet->message.len, stream) == -1) + return -1; + + return 0; +} + +int MqttPacketSerialize(const MqttPacket *packet, Stream *stream) +{ + MqttPacketSerializeFunc f = NULL; + + switch (packet->type) + { + case MqttPacketTypeConnect: + f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize; + break; + + case MqttPacketTypeConnAck: + break; + + case MqttPacketTypePublish: + f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize; + break; + + case MqttPacketTypePubAck: + f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; + break; + + case MqttPacketTypePubRec: + f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; + break; + + case MqttPacketTypePubRel: + f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; + break; + + case MqttPacketTypePubComp: + f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; + break; + + case MqttPacketTypeSubscribe: + f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize; + break; + + case MqttPacketTypeSubAck: + break; + + case MqttPacketTypeUnsubscribe: + f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize; + break; + + case MqttPacketTypeUnsubAck: + break; + + case MqttPacketTypePingReq: + f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; + break; + + case MqttPacketTypePingResp: + break; + + case MqttPacketTypeDisconnect: + f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; + break; + + default: + return -1; + } + + assert(f != NULL && "no serializer"); + + return f(packet, stream); +} diff --git a/src/serialize.h b/src/serialize.h new file mode 100644 index 0000000..ac8d38f --- /dev/null +++ b/src/serialize.h @@ -0,0 +1,9 @@ +#ifndef SERIALIZE_H +#define SERIALIZE_H + +typedef struct MqttPacket MqttPacket; +typedef struct Stream Stream; + +int MqttPacketSerialize(const MqttPacket *packet, Stream *stream); + +#endif diff --git a/src/socket.c b/src/socket.c new file mode 100644 index 0000000..91999b2 --- /dev/null +++ b/src/socket.c @@ -0,0 +1,92 @@ +#include "socket.h" + +#include <string.h> +#include <stdio.h> +#include <assert.h> + +#if defined(_WIN32) +#error not implemented yet +#define WIN32_MEAN_AND_LEAN 1 +#include <windows.h> +#else +#include <sys/types.h> +#include <sys/socket.h> +#include <netdb.h> +#include <unistd.h> +#include <arpa/inet.h> +#endif + +int SocketConnect(const char *host, short port) +{ + struct addrinfo hints, *servinfo, *p; + int rv; + char portstr[6]; + int sock; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + assert(snprintf(portstr, sizeof(portstr), "%hu", port) < (int) sizeof(portstr)); + + if ((rv = getaddrinfo(host, portstr, &hints, &servinfo)) != 0) + { + return -1; + } + + for (p = servinfo; p != NULL; p = p->ai_next) + { + if ((sock = socket(p->ai_family, p->ai_socktype, + p->ai_protocol)) == -1) + { + continue; + } + + if (connect(sock, p->ai_addr, p->ai_addrlen) == -1) + { + close(sock); + continue; + } + + break; + } + + freeaddrinfo(servinfo); + + if (p == NULL) + { + return -1; + } + + return sock; +} + +int SocketDisconnect(int sock) +{ +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +int SocketSendAll(int sock, const char *buf, size_t *len) +{ + size_t total = 0; + int rv; + size_t remaining = *len; + + while (remaining > 0) + { + if ((rv = send(sock, buf+total, remaining, 0)) == -1) + { + break; + } + total += rv; + remaining -= rv; + } + + *len = total; + + return rv == -1 ? -1 : 0; +} diff --git a/src/socket.h b/src/socket.h new file mode 100644 index 0000000..344fbc5 --- /dev/null +++ b/src/socket.h @@ -0,0 +1,12 @@ +#ifndef SOCKET_H +#define SOCKET_H + +#include <stdlib.h> + +int SocketConnect(const char *host, short port); + +int SocketDisconnect(int sock); + +int SocketSendAll(int sock, const char *buf, size_t *len); + +#endif diff --git a/src/socketstream.c b/src/socketstream.c new file mode 100644 index 0000000..3bcc411 --- /dev/null +++ b/src/socketstream.c @@ -0,0 +1,75 @@ +#include "socketstream.h" + +#include <assert.h> +#include <string.h> + +#include <arpa/inet.h> + +// close +#include <unistd.h> + +static int SocketStreamClose(Stream *base) +{ + int rv; + SocketStream *stream = (SocketStream *) base; + rv = close(stream->sock); + stream->sock = -1; + return rv; +} + +static int64_t SocketStreamRead(void *ptr, size_t size, Stream *stream) +{ + SocketStream *ss = (SocketStream *) stream; + size_t received = 0; + if (ss->sock == -1) + return -1; + while (received < size) + { + char *p = ((char *) ptr) + received; + ssize_t rv = recv(ss->sock, p, size - received, 0); + // Error + if (rv == -1) + return -1; + // TODO: Closed? + if (rv == 0) + break; + received += (size_t) rv; + } + return received; +} + +static int64_t SocketStreamWrite(const void *ptr, size_t size, Stream *stream) +{ + SocketStream *ss = (SocketStream *) stream; + size_t written = 0; + if (ss->sock == -1) + return -1; + while (written < size) + { + const char *p = ((char *) ptr) + written; + ssize_t rv = send(ss->sock, p, size - written, 0); + if (rv == -1) + return -1; + written += (size_t) rv; + } + return written; +} + +static const StreamOps SocketStreamOps = +{ + SocketStreamRead, + SocketStreamWrite, + SocketStreamClose, + NULL, + NULL +}; + +int SocketStreamOpen(SocketStream *stream, int sock) +{ + assert(stream != NULL); + assert(sock != -1); + memset(stream, 0, sizeof(*stream)); + stream->sock = sock; + stream->base.ops = &SocketStreamOps; + return 0; +} diff --git a/src/socketstream.h b/src/socketstream.h new file mode 100644 index 0000000..76d842a --- /dev/null +++ b/src/socketstream.h @@ -0,0 +1,16 @@ +#ifndef SOCKETSTREAM_H +#define SOCKETSTREAM_H + +#include "stream.h" + +typedef struct SocketStream SocketStream; + +struct SocketStream +{ + Stream base; + int sock; +}; + +int SocketStreamOpen(SocketStream *stream, int sock); + +#endif diff --git a/src/stream.c b/src/stream.c new file mode 100644 index 0000000..5296501 --- /dev/null +++ b/src/stream.c @@ -0,0 +1,77 @@ +#include "stream.h" +#include "misc.h" + +#include <stdio.h> +#include <string.h> +#include <assert.h> +#include <errno.h> + +// htons, ntohs +#include <arpa/inet.h> + +#define STREAM_CHECK_OP(stream, op) \ + do { if ((stream->ops->op) == NULL) \ + { \ + errno = ENOTSUP; \ + return -1; \ + } } while (0) + +int StreamClose(Stream *stream) +{ + if (stream->ops->close) + { + return stream->ops->close(stream); + } + return 0; +} + +int64_t StreamRead(void *ptr, size_t size, Stream *stream) +{ + STREAM_CHECK_OP(stream, read); + int64_t rv = stream->ops->read(ptr, size, stream); +#if defined(STREAM_HEXDUMP_READ) + if (rv >= 0) + { + printf("READ %lu bytes:\n", size); + DumpHex(ptr, size); + } +#endif + return rv; +} + +int64_t StreamReadUint16Be(uint16_t *v, Stream *stream) +{ + STREAM_CHECK_OP(stream, read); + if (StreamRead(v, 2, stream) != 2) + return -1; + *v = ntohs(*v); + return 2; +} + +int64_t StreamWrite(const void *ptr, size_t size, Stream *stream) +{ + STREAM_CHECK_OP(stream, write); +#if defined(STREAM_HEXDUMP_WRITE) + printf("WRITE %lu bytes:\n", size); + DumpHex(ptr, size); +#endif + return stream->ops->write(ptr, size, stream); +} + +int64_t StreamWriteUint16Be(uint16_t v, Stream *stream) +{ + v = htons(v); + return StreamWrite(&v, sizeof(v), stream); +} + +int StreamSeek(Stream *stream, int64_t offset, int whence) +{ + STREAM_CHECK_OP(stream, seek); + return stream->ops->seek(stream, offset, whence); +} + +int64_t StreamTell(Stream *stream) +{ + STREAM_CHECK_OP(stream, tell); + return stream->ops->tell(stream); +} diff --git a/src/stream.h b/src/stream.h new file mode 100644 index 0000000..b577dc6 --- /dev/null +++ b/src/stream.h @@ -0,0 +1,50 @@ +#ifndef STREAM_H +#define STREAM_H + +#include <stdlib.h> +#include <stdint.h> + +#include "stringbuf.h" + +#ifndef SEEK_SET +#define SEEK_SET (-1) +#endif + +#ifndef SEEK_CUR +#define SEEK_CUR (-2) +#endif + +#ifndef SEEK_END +#define SEEK_END (-3) +#endif + +typedef struct Stream Stream; +typedef struct StreamOps StreamOps; + +struct Stream +{ + const StreamOps *ops; +}; + +struct StreamOps +{ + int64_t (*read)(void *ptr, size_t size, Stream *stream); + int64_t (*write)(const void *ptr, size_t size, Stream *stream); + int (*close)(Stream *stream); + int (*seek)(Stream *stream, int64_t offset, int whence); + int64_t (*tell)(Stream *stream); +}; + +int StreamClose(Stream *stream); + +int64_t StreamRead(void *ptr, size_t size, Stream *stream); +int64_t StreamReadUint16Be(uint16_t *v, Stream *stream); + +int64_t StreamWrite(const void *ptr, size_t size, Stream *stream); +int64_t StreamWriteUint16Be(uint16_t v, Stream *stream); + +int StreamSeek(Stream *stream, int64_t offset, int whence); + +int64_t StreamTell(Stream *stream); + +#endif diff --git a/src/stream_mqtt.c b/src/stream_mqtt.c new file mode 100644 index 0000000..25d2e56 --- /dev/null +++ b/src/stream_mqtt.c @@ -0,0 +1,98 @@ +#include "stream_mqtt.h" +#include "stringbuf.h" + +#include <string.h> + +int64_t StreamReadMqttString(char **s, size_t *len, Stream *stream) +{ + StringBuf buf; + int64_t rv; + + if ((rv = StreamReadMqttStringBuf(&buf, stream)) == -1) + return -1; + + *s = buf.data; + *len = buf.len; + + return rv; +} + +int64_t StreamWriteMqttString(const char *s, int len, Stream *stream) +{ + StringBuf buf; + + if (len < 0) + len = strlen(s); + + buf.data = (char *) s; + buf.len = len; + + return StreamWriteMqttStringBuf(&buf, stream); +} + +int64_t StreamReadMqttStringBuf(struct StringBuf *buf, Stream *stream) +{ + uint16_t len; + + if (StreamReadUint16Be(&len, stream) == -1) + return -1; + + if (StringBufInit(buf, len) == -1) + return -1; + + if (StreamRead(buf->data, len, stream) == -1) + { + StringBufDeinit(buf); + return -1; + } + + buf->len = len; + + return len+2; +} + +int64_t StreamWriteMqttStringBuf(const struct StringBuf *buf, Stream *stream) +{ + if (StreamWriteUint16Be(buf->len, stream) == -1) + return -1; + + if (StreamWrite(buf->data, buf->len, stream) == -1) + return -1; + + return 2 + buf->len; +} + +int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream) +{ + size_t multiplier = 1; + unsigned char encodedByte; + *remainingLength = 0; + do + { + if (StreamRead(&encodedByte, 1, stream) != 1) + return -1; + *remainingLength += (encodedByte & 127) * multiplier; + if (multiplier > 128*128*128) + return -1; + multiplier *= 128; + } + while ((encodedByte & 128) != 0); + return 0; +} + +int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream) +{ + size_t nbytes = 0; + do + { + unsigned char encodedByte = remainingLength % 128; + remainingLength /= 128; + if (remainingLength > 0) + encodedByte |= 128; + if (StreamWrite(&encodedByte, 1, stream) != 1) + return -1; + ++nbytes; + } + while (remainingLength > 0); + return nbytes; +} diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h new file mode 100644 index 0000000..458f2ac --- /dev/null +++ b/src/stream_mqtt.h @@ -0,0 +1,17 @@ +#ifndef STREAM_MQTT_H +#define STREAM_MQTT_H + +#include "stream.h" + +int64_t StreamReadMqttString(char **s, size_t *len, Stream *stream); +int64_t StreamWriteMqttString(const char *s, int len, Stream *stream); + +struct StringBuf; + +int64_t StreamReadMqttStringBuf(struct StringBuf *buf, Stream *stream); +int64_t StreamWriteMqttStringBuf(const struct StringBuf *buf, Stream *stream); + +int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream); +int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream); + +#endif diff --git a/src/stringbuf.c b/src/stringbuf.c new file mode 100644 index 0000000..bc71dc3 --- /dev/null +++ b/src/stringbuf.c @@ -0,0 +1,73 @@ +#include "stringbuf.h" + +#include <stdlib.h> +#include <string.h> +#include <assert.h> + +int StringBufInit(StringBuf *buf, size_t size) +{ + assert(buf != NULL); + memset(buf, 0, sizeof(*buf)); + return StringBufGrow(buf, size); +} + +int StringBufInitFromCString(StringBuf *buf, const char *s, int len) +{ + if (len < 0) + len = strlen(s); + return StringBufInitFromData(buf, s, len); +} + +int StringBufInitFromData(StringBuf *buf, const void *ptr, size_t size) +{ + if (StringBufInit(buf, size) != 0) + return -1; + memcpy(buf->data, ptr, size); + buf->len = size; + return 0; +} + +void StringBufDeinit(StringBuf *buf) +{ + assert(buf != NULL); + if (buf->size > 0 && buf->data) + free(buf->data); + memset(buf, 0, sizeof(*buf)); +} + +size_t StringBufAvailable(StringBuf *buf) +{ + assert(buf != NULL); + assert(buf->data != NULL); + assert(buf->len <= buf->size); + return buf->size - buf->len; +} + +int StringBufGrow(StringBuf *buf, size_t size) +{ + assert(buf != NULL); + size_t newSize = buf->size + size; + char *ptr = realloc(buf->data, newSize+1); + if (!ptr) + return -1; + buf->data = ptr; + buf->size = newSize; + buf->data[buf->size] = '\0'; + return 0; +} + +int StringBufAppendData(StringBuf *buf, const void *ptr, size_t size) +{ + assert(buf != NULL); + assert(buf->data != NULL); + assert(ptr != NULL); + assert(size > 0); + if (StringBufAvailable(buf) < size) + { + if (StringBufGrow(buf, size) == -1) + return -1; + } + memcpy(buf->data + buf->len, ptr, size); + buf->len += size; + return 0; +} diff --git a/src/stringbuf.h b/src/stringbuf.h new file mode 100644 index 0000000..dda32cd --- /dev/null +++ b/src/stringbuf.h @@ -0,0 +1,31 @@ +#ifndef STRINGBUF_H +#define STRINGBUF_H + +#include <stdlib.h> + +typedef struct StringBuf StringBuf; + +struct StringBuf +{ + char *data; + int size; + int len; +}; + +int StringBufInit(StringBuf *buf, size_t size); + +int StringBufInitFromCString(StringBuf *buf, const char *s, int len); + +int StringBufInitFromData(StringBuf *buf, const void *ptr, size_t size); + +void StringBufDeinit(StringBuf *buf); + +size_t StringBufAvailable(StringBuf *buf); + +int StringBufGrow(StringBuf *buf, size_t size); + +int StringBufAppendData(StringBuf *buf, const void *ptr, size_t size); + +#define StaticStringBuf(S) { "" S, -1, sizeof(S)-1 } + +#endif diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt new file mode 100644 index 0000000..1a736ab --- /dev/null +++ b/tools/CMakeLists.txt @@ -0,0 +1,7 @@ +ADD_LIBRARY(getopt OBJECT getopt.c) + +ADD_EXECUTABLE(pub pub.c $<TARGET_OBJECTS:getopt>) +TARGET_LINK_LIBRARIES(pub mqtt) + +ADD_EXECUTABLE(sub sub.c $<TARGET_OBJECTS:getopt>) +TARGET_LINK_LIBRARIES(sub mqtt) diff --git a/tools/amalgamate.py b/tools/amalgamate.py new file mode 100644 index 0000000..a807603 --- /dev/null +++ b/tools/amalgamate.py @@ -0,0 +1,90 @@ +import io +import os +import sys + +this_dir = os.path.dirname(os.path.abspath(__file__)) + +src_dir = os.path.join(this_dir, '..', 'src') + +sources = ( + 'queue.h', + 'log.h', + 'misc.c', + 'stringbuf.c', + 'stream.c', + 'socketstream.c', + 'stringstream.c', + 'stream_mqtt.c', + 'socket.c', + 'packet.c', + 'serialize.c', + 'deserialize.c', + 'client.c' +) + + +def is_header(filename): + return os.path.splitext(filename)[1] == '.h' + + +def get_header(src): + root, ext = os.path.splitext(src) + return root + '.h' + + +def read_file(filename): + def tounicode(s): + if sys.version_info[0] == 2: + return s.decode('utf-8') + else: + return s + with open(filename, 'r') as fp: + buf = io.StringIO() + for line in fp: + if line.startswith('#include "'): + if line[10:].startswith('mqtt.h'): + pass + else: + continue + buf.write(tounicode(line)) + return buf.getvalue() + + +def file_header(filename): + filename = os.path.basename(filename) + # how long lines we create + linelen = 72 + # how much space left after the necessary comment markup + chars = linelen - 4 + # how much padding in total for filename + padding = chars - len(filename) + padding_l = padding // 2 + padding_r = padding - padding_l + lines = ( + '', + '/*' + '*'*chars + '*/', + '/*' + ' '*padding_l + filename + ' '*padding_r + '*/', + '/*' + '*'*chars + '*/', + '\n', + ) + return '\n'.join(lines) + + +def write_file(output, srcfilename): + output.write(file_header(srcfilename)) + output.write(read_file(srcfilename)) + + +output_filename = sys.argv[1] + +with open(output_filename, 'w') as out: + for source in sources: + path = os.path.join(src_dir, source) + + if is_header(path): + write_file(out, path) + else: + header = get_header(path) + if os.path.isfile(header): + write_file(out, header) + write_file(out, path) diff --git a/tools/getopt.c b/tools/getopt.c new file mode 100644 index 0000000..5277ed0 --- /dev/null +++ b/tools/getopt.c @@ -0,0 +1,358 @@ +#include <assert.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "getopt.h" + +/* + * Standard getopt global variables. optreset starts as non-zero in order to + * trigger initialization behaviour. + */ +const char * optarg = NULL; +int optind = 1; +int opterr = 1; +int optreset = 1; + +/* + * Quasi-internal global variables -- these are used via GETOPT macros. + */ +const char * getopt_dummy = "(dummy)"; +int getopt_initialized = 0; + +/* + * Internal variables. + */ +static const char * cmdname = NULL; +static struct opt { + const char * os; + size_t olen; + int hasarg; +} * opts = NULL; +static size_t nopts; +static size_t opt_missing; +static size_t opt_default; +static size_t opt_found; +static const char * packedopts; +static char popt[3]; +static int atexit_registered = 0; + +/* Print a message. */ +#define PRINTMSG(...) do { \ + if (cmdname != NULL) \ + fprintf(stderr, "%s: ", cmdname); \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, "\n"); \ +} while (0) + +/* Print an error message and die. */ +#define DIE(...) do { \ + PRINTMSG(__VA_ARGS__); \ + abort(); \ +} while (0) + +/* Print a warning, if warnings are enabled. */ +#define WARN(...) do { \ + if (opterr == 0) \ + break; \ + if (opt_missing != opt_default) \ + break; \ + PRINTMSG(__VA_ARGS__); \ +} while (0) + +/* Free allocated options array. */ +static void +atexit_handler(void) +{ + + free(opts); + opts = NULL; +} + +/* Reset internal state. */ +static void +reset(int argc, char * const argv[]) +{ + const char * p; + + /* If we have arguments, stash argv[0] for error messages. */ + if (argc > 0) { + /* Find the basename, without leading directories. */ + for (p = cmdname = argv[0]; *p != '\0'; p++) { + if (*p == '/') + cmdname = p + 1; + } + } + + /* Discard any registered command-line options. */ + free(opts); + opts = NULL; + + /* Register atexit handler if we haven't done so already. */ + if (!atexit_registered) { + atexit(atexit_handler); + atexit_registered = 1; + } + + /* We will start scanning from the first option. */ + optind = 1; + + /* We're not in the middle of any packed options. */ + packedopts = NULL; + + /* We haven't found any option yet. */ + opt_found = (size_t)(-1); + + /* We're not initialized yet. */ + getopt_initialized = 0; + + /* Finished resetting state. */ + optreset = 0; +} + +/* Search for an option string. */ +static size_t +searchopt(const char * os) +{ + size_t i; + + /* Scan the array of options. */ + for (i = 0; i < nopts; i++) { + /* Is there an option in this slot? */ + if (opts[i].os == NULL) + continue; + + /* Does this match up to the length of the option string? */ + if (strncmp(opts[i].os, os, opts[i].olen)) + continue; + + /* Do we have <option>\0 or <option>= ? */ + if ((os[opts[i].olen] == '\0') || (os[opts[i].olen] == '=')) + return (i); + } + + /* Not found. */ + return (opt_default); +} + +const char * +getopt(int argc, char * const argv[]) +{ + const char * os = NULL; + const char * canonical_os = NULL; + + /* No argument yet. */ + optarg = NULL; + + /* Reset the getopt state if needed. */ + if (optreset) + reset(argc, argv); + + /* If not initialized, return dummy option. */ + if (!getopt_initialized) + return (GETOPT_DUMMY); + + /* If we've run out of arguments, we're done. */ + if (optind >= argc) + return (NULL); + + /* + * If we're not already in the middle of a packed single-character + * options, see if we should start. + */ + if ((packedopts == NULL) && (argv[optind][0] == '-') && + (argv[optind][1] != '-') && (argv[optind][1] != '\0')) { + /* We have one or more single-character options. */ + packedopts = &argv[optind][1]; + } + + /* If we're processing single-character options, fish one out. */ + if (packedopts != NULL) { + /* Construct the option string. */ + popt[0] = '-'; + popt[1] = *packedopts; + popt[2] = '\0'; + os = popt; + + /* We've done this character. */ + packedopts++; + + /* Are we done with this string? */ + if (*packedopts == '\0') { + packedopts = NULL; + optind++; + } + } + + /* If we don't have an option yet, do we have dash-dash? */ + if ((os == NULL) && (argv[optind][0] == '-') && + (argv[optind][1] == '-')) { + /* If this is not "--\0", it's an option. */ + if (argv[optind][2] != '\0') + os = argv[optind]; + + /* Either way, we want to eat the string. */ + optind++; + } + + /* If we have found nothing which looks like an option, we're done. */ + if (os == NULL) + return (NULL); + + /* Search for the potential option. */ + opt_found = searchopt(os); + + /* If the option is not registered, give up now. */ + if (opt_found == opt_default) { + WARN("unknown option: %s", os); + return (os); + } + + /* The canonical option string is the one registered. */ + canonical_os = opts[opt_found].os; + + /* Does the option take an argument? */ + if (opts[opt_found].hasarg) { + /* + * If we're processing packed single-character options, the + * rest of the string is the argument to this option. + */ + if (packedopts != NULL) { + optarg = packedopts; + packedopts = NULL; + optind++; + } + + /* + * If the option string is <option>=<value>, extract that + * value as the option argument. + */ + if (os[opts[opt_found].olen] == '=') + optarg = &os[opts[opt_found].olen + 1]; + + /* + * If we don't have an argument yet, take one from the + * remaining command line. + */ + if ((optarg == NULL) && (optind < argc)) + optarg = argv[optind++]; + + /* If we still have no option, declare it MIA. */ + if (optarg == NULL) { + WARN("option requires an argument: %s", + opts[opt_found].os); + opt_found = opt_missing; + } + } else { + /* If we have --foo=bar, something went wrong. */ + if (os[opts[opt_found].olen] == '=') { + WARN("option doesn't take an argument: %s", + opts[opt_found].os); + opt_found = opt_default; + } + } + + /* Return the canonical option string. */ + return (canonical_os); +} + +size_t +getopt_lookup(const char * os) +{ + + /* Can't reset here. */ + if (optreset) + DIE("Can't reset in the middle of getopt loop"); + + /* We should only be called after initialization is complete. */ + assert(getopt_initialized); + + /* GETOPT_DUMMY should never get passed back to us. */ + assert(os != GETOPT_DUMMY); + + /* + * Make sure the option passed back to us corresponds to the one we + * found earlier. + */ + assert((opt_found == opt_missing) || (opt_found == opt_default) || + ((opt_found < nopts) && (strcmp(os, opts[opt_found].os) == 0))); + + /* Return the option number we identified earlier. */ + return (opt_found); +} + +void +getopt_register_opt(const char * os, size_t ln, int hasarg) +{ + + /* Can't reset here. */ + if (optreset) + DIE("Can't reset in the middle of getopt loop"); + + /* We should only be called during initialization. */ + assert(!getopt_initialized); + + /* We should have space allocated for registering options. */ + assert(opts != NULL); + + /* We should not have registered an option here yet. */ + assert(opts[ln].os == NULL); + + /* Options should be "-X" or "--foo". */ + if ((os[0] != '-') || (os[1] == '\0') || + ((os[1] == '-') && (os[2] == '\0')) || + ((os[1] != '-') && (os[2] != '\0'))) + DIE("Not a valid command-line option: %s", os); + + /* Make sure we haven't already registered this option. */ + if (searchopt(os) != opt_default) + DIE("Command-line option registered twice: %s", os); + + /* Record option. */ + opts[ln].os = os; + opts[ln].olen = strlen(os); + opts[ln].hasarg = hasarg; +} + +void +getopt_register_missing(size_t ln) +{ + + /* Can't reset here. */ + if (optreset) + DIE("Can't reset in the middle of getopt loop"); + + /* We should only be called during initialization. */ + assert(!getopt_initialized); + + /* Record missing-argument value. */ + opt_missing = ln; +} + +void +getopt_setrange(size_t ln) +{ + size_t i; + + /* Can't reset here. */ + if (optreset) + DIE("Can't reset in the middle of getopt loop"); + + /* We should only be called during initialization. */ + assert(!getopt_initialized); + + /* Allocate space for options. */ + opts = malloc(ln * sizeof(struct opt)); + if ((ln > 0) && (opts == NULL)) + DIE("Failed to allocate memory in getopt"); + + /* Initialize options. */ + for (i = 0; i < ln; i++) + opts[i].os = NULL; + + /* Record the number of (potential) options. */ + nopts = ln; + + /* Record default missing-argument and no-such-option values. */ + opt_missing = opt_default = ln + 1; +} diff --git a/tools/getopt.h b/tools/getopt.h new file mode 100644 index 0000000..5ed3145 --- /dev/null +++ b/tools/getopt.h @@ -0,0 +1,175 @@ +#ifndef _GETOPT_H_ +#define _GETOPT_H_ + +#include <setjmp.h> +#include <stddef.h> + +/** + * This getopt implementation parses options of the following forms: + * -a -b -c foo (single-character options) + * -abc foo (packed single-character options) + * -abcfoo (packed single-character options and an argument) + * --foo bar (long option) + * --foo=bar (long option and argument separated by '=') + * + * It does not support abbreviated options (e.g., interpreting --foo as + * --foobar when there are no other --foo* options) since that misfeature + * results in breakage when new options are added. It also does not support + * options appearing after non-options (e.g., "cp foo bar -R") since that is + * a horrible GNU perversion. + */ + +/* Work around LLVM bug. */ +#ifdef __clang__ +#warning Working around bug in LLVM optimizer +#warning For more details see https://llvm.org/bugs/show_bug.cgi?id=27190 +#define DO_SETJMP _DO_SETJMP(__LINE__) +#define _DO_SETJMP(x) __DO_SETJMP(x) +#define __DO_SETJMP(x) \ + void * getopt_initloop = && getopt_initloop_ ## x; \ + getopt_initloop_ ## x: +#define DO_LONGJMP \ + goto *getopt_initloop +#else +#define DO_SETJMP \ + sigjmp_buf getopt_initloop; \ + if (!getopt_initialized) \ + sigsetjmp(getopt_initloop, 0) +#define DO_LONGJMP \ + siglongjmp(getopt_initloop, 1) +#endif + +/* Avoid namespace collisions with libc getopt. */ +#define getopt libcperciva_getopt +#define optarg libcperciva_optarg +#define optind libcperciva_optind +#define opterr libcperciva_opterr +#define optreset libcperciva_optreset + +/* Standard getopt global variables. */ +extern const char * optarg; +extern int optind, opterr, optreset; + +/* Dummy option string, equal to "(dummy)". */ +#define GETOPT_DUMMY getopt_dummy + +/** + * GETOPT(argc, argv): + * When called for the first time (or the first time after optreset is set to + * a nonzero value), return GETOPT_DUMMY, aka. "(dummy)". Thereafter, return + * the next option string and set optarg / optind appropriately; abort if not + * properly initialized when not being called for the first time. + */ +#define GETOPT(argc, argv) getopt(argc, argv) + +/** + * GETOPT_SWITCH(ch): + * Jump to the appropriate GETOPT_OPT, GETOPT_OPTARG, GETOPT_MISSING_ARG, or + * GETOPT_DEFAULT based on the option string ${ch}. When called for the first + * time, perform magic to index the options. + * + * GETOPT_SWITCH(ch) is equivalent to "switch (ch)" in a standard getopt loop. + */ +#define GETOPT_SWITCH(ch) \ + volatile size_t getopt_ln_min = __LINE__; \ + volatile size_t getopt_ln = getopt_ln_min - 1; \ + volatile int getopt_default_missing = 0; \ + DO_SETJMP; \ + switch (getopt_initialized ? getopt_lookup(ch) + getopt_ln_min : getopt_ln++) + +/** + * GETOPT_OPT(os): + * Jump to this point when the option string ${os} is passed to GETOPT_SWITCH. + * + * GETOPT_OPT("-x") is equivalent to "case 'x'" in a standard getopt loop + * which has an optstring containing "x". + */ +#define GETOPT_OPT(os) _GETOPT_OPT(os, __LINE__) +#define _GETOPT_OPT(os, ln) __GETOPT_OPT(os, ln) +#define __GETOPT_OPT(os, ln) \ + case ln: \ + if (getopt_initialized) \ + goto getopt_skip_ ## ln; \ + getopt_register_opt(os, ln - getopt_ln_min, 0); \ + DO_LONGJMP; \ + getopt_skip_ ## ln + +/** + * GETOPT_OPTARG(os): + * Jump to this point when the option string ${os} is passed to GETOPT_SWITCH, + * unless no argument is available, in which case jump to GETOPT_MISSING_ARG + * (if present) or GETOPT_DEFAULT (if not). + * + * GETOPT_OPTARG("-x") is equivalent to "case 'x'" in a standard getopt loop + * which has an optstring containing "x:". + */ +#define GETOPT_OPTARG(os) _GETOPT_OPTARG(os, __LINE__) +#define _GETOPT_OPTARG(os, ln) __GETOPT_OPTARG(os, ln) +#define __GETOPT_OPTARG(os, ln) \ + case ln: \ + if (getopt_initialized) \ + goto getopt_skip_ ## ln; \ + getopt_register_opt(os, ln - getopt_ln_min, 1); \ + DO_LONGJMP; \ + getopt_skip_ ## ln + +/** + * GETOPT_MISSING_ARG: + * Jump to this point if an option string specified in GETOPT_OPTARG is seen + * but no argument is available. + * + * GETOPT_MISSING_ARG is equivalent to "case ':'" in a standard getopt loop + * which has an optstring starting with ":". As such, it also has the effect + * of disabling warnings about invalid options, as if opterr had been zeroed. + */ +#define GETOPT_MISSING_ARG _GETOPT_MISSING_ARG(__LINE__) +#define _GETOPT_MISSING_ARG(ln) __GETOPT_MISSING_ARG(ln) +#define __GETOPT_MISSING_ARG(ln) \ + case ln: \ + if (getopt_initialized) \ + goto getopt_skip_ ## ln; \ + getopt_register_missing(ln - getopt_ln_min); \ + DO_LONGJMP; \ + getopt_skip_ ## ln + +/** + * GETOPT_DEFAULT: + * Jump to this point if an unrecognized option is seen or if an option + * specified in GETOPT_OPTARG is seen, no argument is available, and there is + * no GETOPT_MISSING_ARG label. + * + * GETOPT_DEFAULT is equivalent to "case '?'" in a standard getopt loop. + * + * NOTE: This MUST be present in the GETOPT_SWITCH statement, and MUST occur + * after all other GETOPT_* labels. + */ +#define GETOPT_DEFAULT _GETOPT_DEFAULT(__LINE__) +#define _GETOPT_DEFAULT(ln) __GETOPT_DEFAULT(ln) +#define __GETOPT_DEFAULT(ln) \ + goto getopt_skip_ ## ln; \ + case ln: \ + getopt_initialized = 1; \ + break; \ + default: \ + if (getopt_initialized) \ + goto getopt_skip_ ## ln; \ + if (!getopt_default_missing) { \ + getopt_setrange(ln - getopt_ln_min); \ + getopt_default_missing = 1; \ + } \ + DO_LONGJMP; \ + getopt_skip_ ## ln + +/* + * The back-end implementation. These should be considered internal + * interfaces and not used directly. + */ +const char * getopt(int, char * const []); +size_t getopt_lookup(const char *); +void getopt_register_opt(const char *, size_t, int); +void getopt_register_missing(size_t); +void getopt_setrange(size_t); +extern const char * getopt_dummy; +extern int getopt_initialized; + +#endif /* !_GETOPT_H_ */ diff --git a/tools/pub.c b/tools/pub.c new file mode 100644 index 0000000..9b5af82 --- /dev/null +++ b/tools/pub.c @@ -0,0 +1,96 @@ +#include <stdlib.h> +#include <stdio.h> + +#include "mqtt.h" +#include "getopt.h" + +struct options +{ + int qos; + int retain; + const char *topic; + const char *message; +}; + +void onConnect(MqttClient *client, MqttConnectionStatus status, + int sessionPresent) +{ + struct options *options = (struct options *) MqttClientGetUserData(client); + (void) client; + printf("onConnect rv=%d sessionPresent=%d\n", status, sessionPresent); + MqttClientPublishCString(client, options->qos, options->retain, + options->topic, options->message); + if (options->qos == 0) + MqttClientDisconnect(client); +} + +void onPublish(MqttClient *client, int id) +{ + printf("onPublish id=%d\n", id); + MqttClientDisconnect(client); +} + +void usage(const char *prog) +{ + fprintf(stderr, "%s [--qos QOS] [--retain] [--topic TOPIC] [--message MESSAGE]\n", + prog); + exit(1); +} + +int main(int argc, char **argv) +{ + MqttClient *client; + const char *opt; + struct options options; + + options.qos = 0; + options.retain = 0; + options.topic = "my/topic"; + options.message = "hello, world!"; + + while ((opt = GETOPT(argc, argv)) != NULL) + { + GETOPT_SWITCH(opt) + { + GETOPT_OPTARG("--qos"): + options.qos = strtol(optarg, NULL, 10); + if (options.qos < 0 || options.qos > 2) + { + fprintf(stderr, "invalid qos: %s\n", optarg); + return 1; + } + break; + + GETOPT_OPT("--retain"): + options.retain = 1; + break; + + GETOPT_OPTARG("--topic"): + options.topic = optarg; + break; + + GETOPT_OPTARG("--message"): + options.message = optarg; + break; + + GETOPT_MISSING_ARG: + fprintf(stderr, "missing argument to: %s\n", opt); + + GETOPT_DEFAULT: + usage(argv[0]); + break; + } + } + + client = MqttClientNew(NULL, 1); + + MqttClientSetOnConnect(client, onConnect); + MqttClientSetOnPublish(client, onPublish); + MqttClientSetUserData(client, &options); + + MqttClientConnect(client, "test.mosquitto.org", 1883, 60); + + MqttClientRun(client); + + return 0; +} diff --git a/tools/sub.c b/tools/sub.c new file mode 100644 index 0000000..634c1e3 --- /dev/null +++ b/tools/sub.c @@ -0,0 +1,103 @@ +#include <stdlib.h> +#include <stdio.h> + +#include "mqtt.h" +#include "getopt.h" + +struct options +{ + int qos; + const char *topic; + int clean; + const char *client_id; +}; + +void onConnect(MqttClient *client, MqttConnectionStatus status, + int sessionPresent) +{ + struct options *options = (struct options *) MqttClientGetUserData(client); + (void) client; + printf("onConnect rv=%d sessionPresent=%d\n", status, sessionPresent); + MqttClientSubscribe(client, options->topic, options->qos); +} + +void onSubscribe(MqttClient *client, int id, MqttSubscriptionStatus status) +{ + (void) client; + printf("onSubscribe id=%d status=%d\n", id, status); +} + +void onMessage(MqttClient *client, const char *topic, const void *data, + size_t size) +{ + (void) client; + printf("onMessage topic=<%s> message=<%.*s>\n", topic, (int) size, + (char *) data); + // MqttClientUnsubscribe(client, topic); +} + +void usage(const char *prog) +{ + fprintf(stderr, "%s [--qos QOS] [--topic TOPIC] [--clean] [--id ID]\n", + prog); + exit(1); +} + +int main(int argc, char **argv) +{ + MqttClient *client; + const char *opt; + struct options options; + + options.qos = 0; + options.topic = "$SYS/broker/load/messages/#"; + options.clean = 1; + options.client_id = NULL; + + while ((opt = GETOPT(argc, argv)) != NULL) + { + GETOPT_SWITCH(opt) + { + GETOPT_OPTARG("--qos"): + options.qos = strtol(optarg, NULL, 10); + if (options.qos < 0 || options.qos > 2) + { + fprintf(stderr, "invalid qos: %s\n", optarg); + return 1; + } + break; + + GETOPT_OPTARG("--topic"): + options.topic = optarg; + break; + + GETOPT_OPT("--no-clean"): + options.clean = 0; + break; + + GETOPT_OPTARG("--id"): + options.client_id = optarg; + break; + + GETOPT_MISSING_ARG: + fprintf(stderr, "missing argument to: %s\n", opt); + + GETOPT_DEFAULT: + usage(argv[0]); + break; + } + } + + client = MqttClientNew(options.client_id, options.clean); + + MqttClientSetOnConnect(client, onConnect); + MqttClientSetOnSubscribe(client, onSubscribe); + MqttClientSetOnMessage(client, onMessage); + MqttClientSetUserData(client, &options); + + MqttClientConnect(client, "test.mosquitto.org", 1883, 60); + + MqttClientRun(client); + + return 0; +} |
