diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:39:08 +0200 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:39:08 +0200 |
| commit | 296da7a73501fb7ce8c0f4c0e18060ebf7eada97 (patch) | |
| tree | e776d2a2e7aff048fb82c7a87a58fbabb0c73395 /amalgamation | |
| parent | 7aeef53b089272f4633cc40512296bfd884a58d4 (diff) | |
| download | mqtt-296da7a73501fb7ce8c0f4c0e18060ebf7eada97.tar.gz mqtt-296da7a73501fb7ce8c0f4c0e18060ebf7eada97.zip | |
Update amalgamation
Diffstat (limited to 'amalgamation')
| -rw-r--r-- | amalgamation/mqtt.c | 2853 | ||||
| -rw-r--r-- | amalgamation/mqtt.h | 9 |
2 files changed, 1449 insertions, 1413 deletions
diff --git a/amalgamation/mqtt.c b/amalgamation/mqtt.c index 1c8f88a..c0ff379 100644 --- a/amalgamation/mqtt.c +++ b/amalgamation/mqtt.c @@ -4529,6 +4529,284 @@ int n, r, l; #endif /**********************************************************************/ +/* socket.h */ +/**********************************************************************/ + +#ifndef SOCKET_H +#define SOCKET_H + + +#include <stdlib.h> +#include <stdint.h> + +#if defined(_WIN32) +#define SocketErrno (WSAGetLastError()) +#define SOCKET_EINPROGRESS (WSAEWOULDBLOCK) +#else +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/select.h> +#include <netdb.h> +#include <unistd.h> +#include <arpa/inet.h> +#include <fcntl.h> +#include <errno.h> +#define SocketErrno (errno) +#define SOCKET_EINPROGRESS (EINPROGRESS) +#endif + +int SocketConnect(const char *host, short port, int nonblocking); + +int SocketDisconnect(int sock); + +int SocketSendAll(int sock, const char *buf, size_t *len); + +enum +{ + EV_READ = 1, + EV_WRITE = 2 +}; + +int SocketSelect(int sock, int *events, int timeout); + +int64_t SocketRecv(int sock, void *buf, size_t len, int flags); + +int64_t SocketSend(int sock, const void *buf, size_t len, int flags); + +void SocketSetNonblocking(int sock, int nb); + +int SocketGetError(int sock, int *error); + +int SocketWouldBlock(int error); + +#endif + +/**********************************************************************/ +/* socket.c */ +/**********************************************************************/ + + +#include <string.h> +#include <stdio.h> +#include <assert.h> + +#if defined(_WIN32) +static int InitializeWsa() +{ + WSADATA wsa; + int rc; + if ((rc = WSAStartup(MAKEWORD(2, 2), &wsa)) != 0) + { + LOG_ERROR("WSAStartup failed: %d", rc); + return -1; + } + return 0; +} + +#define close closesocket +#endif + +int SocketConnect(const char *host, short port, int nonblocking) +{ + struct addrinfo hints, *servinfo = NULL, *p = NULL; + int rv; + char portstr[6]; + int sock; + +#if defined(_WIN32) + if (InitializeWsa() != 0) + { + return -1; + } +#endif + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + assert(snprintf(portstr, sizeof(portstr), "%hu", port) < (int) sizeof(portstr)); + + if ((rv = getaddrinfo(host, portstr, &hints, &servinfo)) != 0) + { + goto cleanup; + } + + for (p = servinfo; p != NULL; p = p->ai_next) + { + if ((sock = socket(p->ai_family, p->ai_socktype, + p->ai_protocol)) == -1) + { + continue; + } + + if (nonblocking) + { + SocketSetNonblocking(sock, 1); + } + + if (connect(sock, p->ai_addr, p->ai_addrlen) == -1) + { + int err = SocketErrno; + if (err == SOCKET_EINPROGRESS) + break; + close(sock); + continue; + } + + break; + } + +cleanup: + + if (servinfo) + { + freeaddrinfo(servinfo); + } + + if (p == NULL) + { +#if defined(_WIN32) + WSACleanup(); +#endif + return -1; + } + + return sock; +} + +int SocketDisconnect(int sock) +{ + int rc = close(sock); +#if defined(_WIN32) + WSACleanup(); +#endif + return rc; +} + +int64_t SocketRecv(int sock, void *buf, size_t len, int flags) +{ + return recv(sock, buf, len, flags); +} + +int64_t SocketSend(int sock, const void *buf, size_t len, int flags) +{ + return send(sock, buf, len, flags); +} + +int SocketSendAll(int sock, const char *buf, size_t *len) +{ + size_t total = 0; + int rv; + size_t remaining = *len; + + while (remaining > 0) + { + if ((rv = send(sock, buf+total, remaining, 0)) == -1) + { + break; + } + total += rv; + remaining -= rv; + } + + *len = total; + + return rv == -1 ? -1 : 0; +} + +int SocketSelect(int sock, int *events, int timeout) +{ + fd_set rfd, wfd; + struct timeval tv; + int rv; + + assert(sock != -1); + assert(events != NULL); + assert(*events != 0); + + FD_ZERO(&rfd); + FD_ZERO(&wfd); + + if (*events & EV_READ) + { + FD_SET(sock, &rfd); + } + + if (*events & EV_WRITE) + { + FD_SET(sock, &wfd); + } + + memset(&tv, 0, sizeof(tv)); + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout - (tv.tv_sec * 1000)) * 1000; + + *events = 0; + + rv = select(sock+1, &rfd, &wfd, NULL, &tv); + + if (rv < 0) + { + return rv; + } + + if (FD_ISSET(sock, &wfd)) + { + *events = EV_WRITE; + } + + if (FD_ISSET(sock, &rfd)) + { + *events = EV_READ; + } + + return rv; +} + +void SocketSetNonblocking(int sock, int nb) +{ +#if defined(_WIN32) + unsigned int yes = nb; + ioctlsocket(s, FIONBIO, &yes); +#else + int flags = fcntl(sock, F_GETFL, 0); + if (nb) + flags |= O_NONBLOCK; + else + flags &= ~O_NONBLOCK; + fcntl(sock, F_SETFL, flags); +#endif +} + +int SocketGetOpt(int sock, int level, int name, void *val, int *len) +{ +#if defined(_WIN32) + return getsockopt(sock, level, name, (char *) val, len); +#else + socklen_t _len = *len; + int rc = getsockopt(sock, level, name, val, &_len); + *len = _len; + return rc; +#endif +} + +int SocketGetError(int sock, int *error) +{ + int len = sizeof(*error); + return SocketGetOpt(sock, SOL_SOCKET, SO_ERROR, error, &len); +} + +int SocketWouldBlock(int error) +{ +#if defined(_WIN32) + return error == WSAEWOULDBLOCK; +#else + return error == EWOULDBLOCK || error == EAGAIN; +#endif +} + +/**********************************************************************/ /* stream.h */ /**********************************************************************/ @@ -4560,9 +4838,11 @@ int StreamClose(Stream *stream); int64_t StreamRead(void *ptr, size_t size, Stream *stream); int64_t StreamReadUint16Be(uint16_t *v, Stream *stream); +int64_t StreamReadByte(unsigned char *byte, Stream *stream); int64_t StreamWrite(const void *ptr, size_t size, Stream *stream); int64_t StreamWriteUint16Be(uint16_t v, Stream *stream); +int64_t StreamWriteByte(unsigned char byte, Stream *stream); int StreamSeek(Stream *stream, int64_t offset, int whence); @@ -4621,6 +4901,11 @@ int64_t StreamReadUint16Be(uint16_t *v, Stream *stream) return 2; } +int64_t StreamReadByte(unsigned char *byte, Stream *stream) +{ + return StreamRead(byte, sizeof(*byte), stream); +} + int64_t StreamWrite(const void *ptr, size_t size, Stream *stream) { STREAM_CHECK_OP(stream, write); @@ -4639,6 +4924,11 @@ int64_t StreamWriteUint16Be(uint16_t v, Stream *stream) return StreamWrite(data, sizeof(data), stream); } +int64_t StreamWriteByte(unsigned char byte, Stream *stream) +{ + return StreamWrite(&byte, sizeof(byte), stream); +} + int StreamSeek(Stream *stream, int64_t offset, int whence) { STREAM_CHECK_OP(stream, seek); @@ -4746,315 +5036,248 @@ int SocketStreamOpen(SocketStream *stream, int sock) } /**********************************************************************/ -/* stream_mqtt.h */ +/* stringstream.h */ /**********************************************************************/ -#ifndef STREAM_MQTT_H -#define STREAM_MQTT_H +#ifndef STRINGSTREAM_H +#define STRINGSTREAM_H +#include <stdio.h> +typedef struct StringStream StringStream; -int64_t StreamReadMqttString(bstring *buf, Stream *stream); -int64_t StreamWriteMqttString(const_bstring buf, Stream *stream); +struct StringStream +{ + Stream base; + bstring buffer; + int64_t pos; +}; -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream); -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream); +int StringStreamInit(StringStream *stream); + +int StringStreamInitFromBstring(StringStream *stream, bstring buffer); #endif /**********************************************************************/ -/* stream_mqtt.c */ +/* stringstream.c */ /**********************************************************************/ -#include <string.h> +#include <assert.h> -int64_t StreamReadMqttString(bstring *buf, Stream *stream) +static int StringStreamClose(Stream *base) { - uint16_t len; - bstring result; - - if (StreamReadUint16Be(&len, stream) == -1) - return -1; - - /* We need 1 extra byte for a NULL terminator. bfromcstralloc doesn't do - any size snapping. */ - result = bfromcstralloc(len+1, ""); + StringStream *ss = (StringStream *) base; + bdestroy(ss->buffer); + ss->buffer = NULL; + return 0; +} - if (!result) - return -1; +static int64_t StringStreamRead(void *ptr, size_t size, Stream *stream) +{ + StringStream *ss = (StringStream *) stream; + int64_t available = blength(ss->buffer) - ss->pos; + void *bufptr; - if (StreamRead(bdata(result), len, stream) == -1) + if (available <= 0) { - bdestroy(result); return -1; } - result->slen = len; - result->data[len] = '\0'; + if (size > (size_t) available) + size = available; - *buf = result; + /* Use a temp buffer pointer to make some warnings disappear when using + GCC */ + bufptr = bdataofs(ss->buffer, ss->pos); + memcpy(ptr, bufptr, size); - return len+2; + ss->pos += size; + + return size; } -int64_t StreamWriteMqttString(const_bstring buf, Stream *stream) +static int64_t StringStreamWrite(const void *ptr, size_t size, Stream *stream) { - if (StreamWriteUint16Be(blength(buf), stream) == -1) + StringStream *ss = (StringStream *) stream; + struct tagbstring buf; + if (ss->buffer->mlen <= 0) return -1; - - if (StreamWrite(bdata(buf), blength(buf), stream) == -1) - return -1; - - return 2 + blength(buf); + btfromblk(buf, ptr, size); + bsetstr(ss->buffer, ss->pos, &buf, '\0'); + ss->pos += size; + return size; } -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream) +int StringStreamSeek(Stream *base, int64_t offset, int whence) { - size_t multiplier = 1; - unsigned char encodedByte; - *remainingLength = 0; - do + StringStream *ss = (StringStream *) base; + int64_t newpos = 0; + + if (whence == SEEK_SET) { - if (StreamRead(&encodedByte, 1, stream) != 1) - return -1; - *remainingLength += (encodedByte & 127) * multiplier; - if (multiplier > 128*128*128) - return -1; - multiplier *= 128; + newpos = offset; } - while ((encodedByte & 128) != 0); - return 0; -} - -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream) -{ - size_t nbytes = 0; - do + else if (whence == SEEK_CUR) { - unsigned char encodedByte = remainingLength % 128; - remainingLength /= 128; - if (remainingLength > 0) - encodedByte |= 128; - if (StreamWrite(&encodedByte, 1, stream) != 1) - return -1; - ++nbytes; + newpos = ss->pos + offset; + } + else if (whence == SEEK_END) + { + newpos = blength(ss->buffer) - offset; + } + else + { + return -1; } - while (remainingLength > 0); - return nbytes; -} - -/**********************************************************************/ -/* socket.h */ -/**********************************************************************/ - -#ifndef SOCKET_H -#define SOCKET_H + if (newpos > blength(ss->buffer)) + return -1; -#include <stdlib.h> -#include <stdint.h> + if (newpos < 0) + return -1; -int SocketConnect(const char *host, short port); + ss->pos = newpos; -int SocketDisconnect(int sock); + return 0; +} -int SocketSendAll(int sock, const char *buf, size_t *len); +int64_t StringStreamTell(Stream *base) +{ + StringStream *ss = (StringStream *) base; + return ss->pos; +} -enum +static const StreamOps StringStreamOps = { - EV_READ = 1, - EV_WRITE = 2 + StringStreamRead, + StringStreamWrite, + StringStreamClose, + StringStreamSeek, + StringStreamTell }; -int SocketSelect(int sock, int *events, int timeout); - -int64_t SocketRecv(int sock, void *buf, size_t len, int flags); - -int64_t SocketSend(int sock, const void *buf, size_t len, int flags); +int StringStreamInit(StringStream *stream) +{ + assert(stream != NULL); + memset(stream, 0, sizeof(*stream)); + stream->pos = 0; + stream->buffer = bfromcstr(""); + stream->base.ops = &StringStreamOps; + return 0; +} -#endif +int StringStreamInitFromBstring(StringStream *stream, bstring buffer) +{ + assert(stream != NULL); + memset(stream, 0, sizeof(*stream)); + stream->pos = 0; + stream->buffer = buffer; + stream->base.ops = &StringStreamOps; + return 0; +} /**********************************************************************/ -/* socket.c */ +/* stream_mqtt.h */ /**********************************************************************/ +#ifndef STREAM_MQTT_H +#define STREAM_MQTT_H -#include <string.h> -#include <stdio.h> -#include <assert.h> -#if defined(_WIN32) -#else -#include <sys/types.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/select.h> -#include <netdb.h> -#include <unistd.h> -#include <arpa/inet.h> -#endif -#if defined(_WIN32) -static int InitializeWsa() -{ - WSADATA wsa; - int rc; - if ((rc = WSAStartup(MAKEWORD(2, 2), &wsa)) != 0) - { - LOG_ERROR("WSAStartup failed: %d", rc); - return -1; - } - return 0; -} - -#define close closesocket -#endif +int64_t StreamReadMqttString(bstring *buf, Stream *stream); +int64_t StreamWriteMqttString(const_bstring buf, Stream *stream); -int SocketConnect(const char *host, short port) -{ - struct addrinfo hints, *servinfo, *p = NULL; - int rv; - char portstr[6]; - int sock; +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream); +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream); -#if defined(_WIN32) - if (InitializeWsa() != 0) - { - return -1; - } #endif - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - - assert(snprintf(portstr, sizeof(portstr), "%hu", port) < (int) sizeof(portstr)); +/**********************************************************************/ +/* stream_mqtt.c */ +/**********************************************************************/ - if ((rv = getaddrinfo(host, portstr, &hints, &servinfo)) != 0) - { - goto cleanup; - } - for (p = servinfo; p != NULL; p = p->ai_next) - { - if ((sock = socket(p->ai_family, p->ai_socktype, - p->ai_protocol)) == -1) - { - continue; - } +#include <string.h> - if (connect(sock, p->ai_addr, p->ai_addrlen) == -1) - { - close(sock); - continue; - } +int64_t StreamReadMqttString(bstring *buf, Stream *stream) +{ + uint16_t len; + bstring result; - break; - } + if (StreamReadUint16Be(&len, stream) == -1) + return -1; - freeaddrinfo(servinfo); + /* We need 1 extra byte for a NULL terminator. bfromcstralloc doesn't do + any size snapping. */ + result = bfromcstralloc(len+1, ""); -cleanup: + if (!result) + return -1; - if (p == NULL) + if (StreamRead(bdata(result), len, stream) == -1) { -#if defined(_WIN32) - WSACleanup(); -#endif + bdestroy(result); return -1; } - return sock; -} - -int SocketDisconnect(int sock) -{ - int rc = close(sock); -#if defined(_WIN32) - WSACleanup(); -#endif - return rc; -} + result->slen = len; + result->data[len] = '\0'; -int64_t SocketRecv(int sock, void *buf, size_t len, int flags) -{ - return recv(sock, buf, len, flags); -} + *buf = result; -int64_t SocketSend(int sock, const void *buf, size_t len, int flags) -{ - return send(sock, buf, len, flags); + return len+2; } -int SocketSendAll(int sock, const char *buf, size_t *len) +int64_t StreamWriteMqttString(const_bstring buf, Stream *stream) { - size_t total = 0; - int rv; - size_t remaining = *len; - - while (remaining > 0) - { - if ((rv = send(sock, buf+total, remaining, 0)) == -1) - { - break; - } - total += rv; - remaining -= rv; - } + if (StreamWriteUint16Be(blength(buf), stream) == -1) + return -1; - *len = total; + if (StreamWrite(bdata(buf), blength(buf), stream) == -1) + return -1; - return rv == -1 ? -1 : 0; + return 2 + blength(buf); } -int SocketSelect(int sock, int *events, int timeout) +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream) { - fd_set rfd, wfd; - struct timeval tv; - int rv; - - assert(sock != -1); - assert(events != NULL); - assert(*events != 0); - - FD_ZERO(&rfd); - FD_ZERO(&wfd); - - if (*events & EV_READ) - { - FD_SET(sock, &rfd); - } - - if (*events & EV_WRITE) - { - FD_SET(sock, &wfd); - } - - memset(&tv, 0, sizeof(tv)); - tv.tv_sec = timeout / 1000; - tv.tv_usec = (timeout - (tv.tv_sec * 1000)) * 1000; - - *events = 0; - - rv = select(sock+1, &rfd, &wfd, NULL, &tv); - - if (rv < 0) - { - return rv; - } - - if (FD_ISSET(sock, &wfd)) + unsigned char encodedByte; + do { - *events = EV_WRITE; + if (StreamRead(&encodedByte, 1, stream) != 1) + return -1; + *remainingLength += (encodedByte & 127) * (*mul); + if ((*mul) > 128*128*128) + return -1; + (*mul) *= 128; } + while ((encodedByte & 128) != 0); + *mul = 0; + return 0; +} - if (FD_ISSET(sock, &rfd)) +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream) +{ + do { - *events = EV_READ; + size_t tmp = *remainingLength; + unsigned char encodedByte = tmp % 128; + tmp /= 128; + if (tmp > 0) + encodedByte |= 128; + if (StreamWrite(&encodedByte, 1, stream) != 1) + { + return -1; + } + *remainingLength = tmp; } - - return rv; + while (*remainingLength > 0); + return 0; } /**********************************************************************/ @@ -5089,87 +5312,35 @@ enum MqttPacketTypeDisconnect = 0xE }; +enum MqttPacketState +{ + MqttPacketStateReadType, + MqttPacketStateReadRemainingLength, + MqttPacketStateReadPayload, + MqttPacketStateReadComplete, + + MqttPacketStateWriteType, + MqttPacketStateWriteRemainingLength, + MqttPacketStateWritePayload, + MqttPacketStateWriteComplete +}; + +struct MqttMessage; + typedef struct MqttPacket MqttPacket; struct MqttPacket { int type; - uint16_t id; - int state; int flags; - int64_t sentAt; + int state; + uint16_t id; + size_t remainingLength; + size_t remainingLengthMul; + /* TODO: maybe switch to have a StringStream here? */ + bstring payload; + struct MqttMessage *message; SIMPLEQ_ENTRY(MqttPacket) sendQueue; - TAILQ_ENTRY(MqttPacket) messages; -}; - -#define MqttPacketType(packet) (((MqttPacket *) (packet))->type) - -#define MqttPacketId(packet) (((MqttPacket *) (packet))->id) - -#define MqttPacketSentAt(packet) (((MqttPacket *) (packet))->sentAt) - -typedef struct MqttPacketConnect MqttPacketConnect; - -struct MqttPacketConnect -{ - MqttPacket base; - char connectFlags; - uint16_t keepAlive; - bstring clientId; - bstring willTopic; - bstring willMessage; - bstring userName; - bstring password; -}; - -typedef struct MqttPacketConnAck MqttPacketConnAck; - -struct MqttPacketConnAck -{ - MqttPacket base; - unsigned char connAckFlags; - unsigned char returnCode; -}; - -typedef struct MqttPacketPublish MqttPacketPublish; - -struct MqttPacketPublish -{ - MqttPacket base; - bstring topicName; - bstring message; - char qos; - char dup; - char retain; -}; - -#define MqttPacketPublishQos(p) (((MqttPacketPublish *) p)->qos) -#define MqttPacketPublishDup(p) (((MqttPacketPublish *) p)->dup) -#define MqttPacketPublishRetain(p) (((MqttPacketPublish *) p)->retain) - -typedef struct MqttPacketSubscribe MqttPacketSubscribe; - -struct MqttPacketSubscribe -{ - MqttPacket base; - struct bstrList *topicFilters; - int *qos; -}; - -typedef struct MqttPacketSubAck MqttPacketSubAck; - -struct MqttPacketSubAck -{ - MqttPacket base; - unsigned char *returnCode; -}; - -typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe; - -struct MqttPacketUnsubscribe -{ - MqttPacket base; - bstring topicFilter; }; const char *MqttPacketName(int type); @@ -5180,8 +5351,6 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id); void MqttPacketFree(MqttPacket *packet); -int MqttPacketHasId(const MqttPacket *packet); - #endif /**********************************************************************/ @@ -5216,42 +5385,16 @@ const char *MqttPacketName(int type) } } -static MQTT_INLINE size_t MqttPacketStructSize(int type) -{ - switch (type) - { - case MqttPacketTypeConnect: return sizeof(MqttPacketConnect); - case MqttPacketTypeConnAck: return sizeof(MqttPacketConnAck); - case MqttPacketTypePublish: return sizeof(MqttPacketPublish); - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: return sizeof(MqttPacket); - case MqttPacketTypeSubscribe: return sizeof(MqttPacketSubscribe); - case MqttPacketTypeSubAck: return sizeof(MqttPacketSubAck); - case MqttPacketTypeUnsubscribe: return sizeof(MqttPacketUnsubscribe); - case MqttPacketTypeUnsubAck: return sizeof(MqttPacket); - case MqttPacketTypePingReq: return sizeof(MqttPacket); - case MqttPacketTypePingResp: return sizeof(MqttPacket); - case MqttPacketTypeDisconnect: return sizeof(MqttPacket); - default: return (size_t) -1; - } -} - MqttPacket *MqttPacketNew(int type) { MqttPacket *packet = NULL; - packet = (MqttPacket *) calloc(1, MqttPacketStructSize(type)); + packet = (MqttPacket *) calloc(1, sizeof(*packet)); if (!packet) return NULL; packet->type = type; - /* this will make sure that TAILQ_PREV does not segfault if a message - has not been added to a list at any point */ - packet->messages.tqe_prev = &packet->messages.tqe_next; - return packet; } @@ -5266,697 +5409,63 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id) void MqttPacketFree(MqttPacket *packet) { - if (MqttPacketType(packet) == MqttPacketTypeConnect) - { - MqttPacketConnect *p = (MqttPacketConnect *) packet; - bdestroy(p->clientId); - bdestroy(p->willTopic); - bdestroy(p->willMessage); - bdestroy(p->userName); - bdestroy(p->password); - } - else if (MqttPacketType(packet) == MqttPacketTypePublish) - { - MqttPacketPublish *p = (MqttPacketPublish *) packet; - bdestroy(p->topicName); - bdestroy(p->message); - } - else if (MqttPacketType(packet) == MqttPacketTypeSubscribe) - { - MqttPacketSubscribe *p = (MqttPacketSubscribe *) packet; - bstrListDestroy(p->topicFilters); - } - else if (MqttPacketType(packet) == MqttPacketTypeUnsubscribe) - { - MqttPacketUnsubscribe *p = (MqttPacketUnsubscribe *) packet; - bdestroy(p->topicFilter); - } + bdestroy(packet->payload); free(packet); } -int MqttPacketHasId(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return MqttPacketPublishQos(packet) > 0; - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - case MqttPacketTypeSubscribe: - case MqttPacketTypeSubAck: - case MqttPacketTypeUnsubscribe: - case MqttPacketTypeUnsubAck: - return 1; - - default: - return 0; - } -} - -/**********************************************************************/ -/* serialize.h */ -/**********************************************************************/ - -#ifndef SERIALIZE_H -#define SERIALIZE_H - - -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; - -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream); - -#endif - /**********************************************************************/ -/* serialize.c */ +/* message.h */ /**********************************************************************/ +#ifndef MESSAGE_H +#define MESSAGE_H +#include <stdint.h> -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet, - Stream *stream); - -static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); -static const char MqttProtocolLevel = 0x04; - -static MQTT_INLINE size_t MqttStringLengthSerialized(const_bstring s) -{ - return 2 + blength(s); -} - -static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2; - - remainingLength += MqttStringLengthSerialized(packet->clientId); - - if (packet->connectFlags & 0x80) - remainingLength += MqttStringLengthSerialized(packet->userName); - - if (packet->connectFlags & 0x40) - remainingLength += MqttStringLengthSerialized(packet->password); - - if (packet->connectFlags & 0x04) - remainingLength += MqttStringLengthSerialized(packet->willTopic) + - MqttStringLengthSerialized(packet->willMessage); - - return remainingLength; -} - -static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet) -{ - size_t remaining = 2; - int i; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - remaining += MqttStringLengthSerialized(packet->topicFilters->entry[i]); - remaining += 1; - } - - return remaining; -} - -static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet) -{ - return 2 + MqttStringLengthSerialized(packet->topicFilter); -} - -static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(packet->topicName); - - /* Packet id */ - if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2) - { - remainingLength += 2; - } - - remainingLength += blength(packet->message); - - return remainingLength; -} - -static size_t MqttPacketGetRemainingLength(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypeConnect: - return MqttPacketConnectGetRemainingLength( - (MqttPacketConnect *) packet); - - case MqttPacketTypeSubscribe: - return MqttPacketSubscribeGetRemainingLength( - (MqttPacketSubscribe *) packet); - - case MqttPacketTypePublish: - return MqttPacketPublishGetRemainingLength( - (MqttPacketPublish *) packet); - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - return 2; - - case MqttPacketTypeUnsubscribe: - return MqttPacketUnsubscribeGetRemainingLength( - (MqttPacketUnsubscribe *) packet); - - default: - return 0; - } -} - -static int MqttPacketFlags(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return ((MqttPacketPublishDup(packet) & 1) << 3) | - ((MqttPacketPublishQos(packet) & 3) << 1) | - (MqttPacketPublishRetain(packet) & 1); - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - return 0x2; - - default: - return 0; - } -} - -static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream) -{ - unsigned char typeAndFlags; - size_t remainingLength; - - typeAndFlags = ((packet->type & 0x0F) << 4) | - (MqttPacketFlags(packet) & 0x0F); - remainingLength = MqttPacketGetRemainingLength(packet); - - LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type, - MqttPacketName(packet->type), MqttPacketFlags(packet)); - - if (StreamWrite(&typeAndFlags, 1, stream) != 1) - return -1; - - if (StreamWriteRemainingLength(remainingLength, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream) -{ - assert(MqttPacketHasId((const MqttPacket *) packet)); - - if (MqttPacketBaseSerialize(packet, stream) == -1) - return -1; - - if (StreamWriteUint16Be(packet->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream) -{ - if (MqttPacketBaseSerialize(&packet->base, stream) == -1) - return -1; - - if (StreamWriteMqttString(&MqttProtocolId, stream) == -1) - return -1; - - if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1) - return -1; - - if (StreamWrite(&packet->connectFlags, 1, stream) != 1) - return -1; - - if (StreamWriteUint16Be(packet->keepAlive, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->clientId, stream) == -1) - return -1; - - if (packet->connectFlags & 0x04) - { - if (StreamWriteMqttString(packet->willTopic, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->willMessage, stream) == -1) - return -1; - } - - if (packet->connectFlags & 0x80) - { - if (StreamWriteMqttString(packet->userName, stream) == -1) - return -1; - } - - if (packet->connectFlags & 0x40) - { - if (StreamWriteMqttString(packet->password, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream) -{ - int i; - - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - unsigned char qos = (unsigned char) packet->qos[i]; - - if (StreamWriteMqttString(packet->topicFilters->entry[i], stream) == -1) - return -1; - - if (StreamWrite(&qos, 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream) -{ - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicFilter, stream) == -1) - return -1; - - return 0; -} -static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream) +enum MqttMessageState { - if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicName, stream) == -1) - return -1; - - LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet)); - - if (MqttPacketPublishQos(packet) > 0) - { - if (StreamWriteUint16Be(packet->base.id, stream) == -1) - return -1; - } - - if (StreamWrite(bdata(packet->message), blength(packet->message), stream) == -1) - return -1; + MqttMessageStateQueued, + MqttMessageStatePublish, + MqttMessageStateWaitPubAck, + MqttMessageStateWaitPubRec, + MqttMessageStateWaitPubComp, + MqttMessageStateWaitPubRel +}; - return 0; -} +typedef struct MqttMessage MqttMessage; -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream) +struct MqttMessage { - MqttPacketSerializeFunc f = NULL; - - switch (packet->type) - { - case MqttPacketTypeConnect: - f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize; - break; - - case MqttPacketTypeConnAck: - break; - - case MqttPacketTypePublish: - f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize; - break; - - case MqttPacketTypePubAck: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRec: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRel: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubComp: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypeSubscribe: - f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize; - break; - - case MqttPacketTypeSubAck: - break; - - case MqttPacketTypeUnsubscribe: - f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize; - break; - - case MqttPacketTypeUnsubAck: - break; - - case MqttPacketTypePingReq: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - default: - return -1; - } - - assert(f != NULL && "no serializer"); - - return f(packet, stream); -} - -/**********************************************************************/ -/* deserialize.h */ -/**********************************************************************/ - -#ifndef DESERIALIZE_H -#define DESERIALIZE_H - + int state; + int qos; + int retain; + int dup; + int padding; + uint16_t id; + int64_t timestamp; + bstring topic; + bstring payload; + TAILQ_ENTRY(MqttMessage) chain; +}; -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; +typedef struct MqttMessageList MqttMessageList; +TAILQ_HEAD(MqttMessageList, MqttMessage); -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream); +void MqttMessageFree(MqttMessage *msg); #endif /**********************************************************************/ -/* deserialize.c */ +/* message.c */ /**********************************************************************/ -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream); - -static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream) +void MqttMessageFree(MqttMessage *msg) { - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1) - return -1; - - if (StreamRead(&(*packet)->returnCode, 1, stream) != 1) - return -1; - - return 0; -} - -static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t i; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - return -1; - - remainingLength -= 2; - - (*packet)->returnCode = (unsigned char *) malloc( - sizeof(*(*packet)->returnCode) * remainingLength); - - for (i = 0; i < remainingLength; ++i) - { - if (StreamRead(&((*packet)->returnCode[i]), 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t payloadSize = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadMqttString(&(*packet)->topicName, stream) == -1) - return -1; - - LOG_DEBUG("remainingLength:%lu", remainingLength); - - payloadSize = remainingLength - blength((*packet)->topicName) - 2; - - LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet), - payloadSize); - - if (MqttPacketHasId((const MqttPacket *) *packet)) - { - LOG_DEBUG("packet has id"); - payloadSize -= 2; - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - { - return -1; - } - } - - LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize); - - /* Allocate extra byte for a NULL terminator. If the user tries to print - the payload directly. */ - - (*packet)->message = bfromcstralloc(payloadSize+1, ""); - - if (StreamRead(bdata((*packet)->message), payloadSize, stream) == -1) - return -1; - - (*packet)->message->slen = payloadSize; - (*packet)->message->data[payloadSize] = '\0'; - - return 0; -} - -static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - char buffer[256]; - - (void) packet; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - while (remainingLength > 0) - { - size_t l = sizeof(buffer); - - if (remainingLength < l) - l = remainingLength; - - if (StreamRead(buffer, l, stream) != (int64_t) l) - return -1; - - remainingLength -= l; - } - - return 0; -} - -static int ValidateFlags(int type, int flags) -{ - int rv = 0; - - switch (type) - { - case MqttPacketTypePublish: - { - int qos = (flags >> 1) & 2; - if (qos >= 0 && qos <= 2) - rv = 1; - break; - } - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - if (flags == 2) - { - rv = 1; - } - break; - - default: - if (flags == 0) - { - rv = 1; - } - break; - } - - return rv; -} - -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream) -{ - MqttPacketDeserializeFunc deserializer = NULL; - char typeAndFlags; - int type; - int flags; - int rv; - - if (StreamRead(&typeAndFlags, 1, stream) != 1) - return -1; - - type = (typeAndFlags & 0xF0) >> 4; - flags = (typeAndFlags & 0x0F); - - if (!ValidateFlags(type, flags)) - { - return -1; - } - - switch (type) - { - case MqttPacketTypeConnect: - break; - - case MqttPacketTypeConnAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize; - break; - - case MqttPacketTypePublish: - deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize; - break; - - case MqttPacketTypePubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRec: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRel: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubComp: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypeSubscribe: - break; - - case MqttPacketTypeSubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize; - break; - - case MqttPacketTypeUnsubscribe: - break; - - case MqttPacketTypeUnsubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize; - break; - - case MqttPacketTypePingReq: - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - break; - - default: - return -1; - } - - if (!deserializer) - { - deserializer = MqttPacketGenericDeserializer; - } - - *packet = MqttPacketNew(type); - - if (!*packet) - return -1; - - if (type == MqttPacketTypePublish) - { - MqttPacketPublishDup(*packet) = (flags >> 3) & 1; - MqttPacketPublishQos(*packet) = (flags >> 1) & 3; - MqttPacketPublishRetain(*packet) = flags & 1; - } - - rv = deserializer(packet, stream); - - return rv; + bdestroy(msg->topic); + bdestroy(msg->payload); + free(msg); } /**********************************************************************/ @@ -5978,8 +5487,14 @@ int MqttPacketDeserialize(MqttPacket **packet, Stream *stream) #error define PRId64 for your platform #endif -TAILQ_HEAD(MessageList, MqttPacket); -typedef struct MessageList MessageList; +typedef enum MqttClientState MqttClientState; + +enum MqttClientState +{ + MqttClientStateDisconnected, + MqttClientStateConnecting, + MqttClientStateConnected, +}; struct MqttClient { @@ -6009,9 +5524,9 @@ struct MqttClient /* packets waiting to be sent over network */ SIMPLEQ_HEAD(, MqttPacket) sendQueue; /* sent messages that are not done yet */ - MessageList outMessages; + MqttMessageList outMessages; /* received messages that are not done yet */ - MessageList inMessages; + MqttMessageList inMessages; int sessionPresent; /* when was the last packet sent */ int64_t lastPacketSentTime; @@ -6031,18 +5546,16 @@ struct MqttClient int willRetain; /* 1 if client should ignore incoming PUBLISH messages, 0 handle them */ int paused; -}; - -enum MessageState -{ - MessageStateQueued = 100, - MessageStateSend, - MessageStateSent + bstring userName; + bstring password; + /* The packet we are receiving */ + MqttPacket inPacket; + MqttClientState state; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); static int MqttClientQueueSimplePacket(MqttClient *client, int type); -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet); +static int MqttClientSendPacket(MqttClient *client); static int MqttClientRecvPacket(MqttClient *client); static uint16_t MqttClientNextPacketId(MqttClient *client); static void MqttClientProcessMessageQueue(MqttClient *client); @@ -6050,14 +5563,14 @@ static void MqttClientClearQueues(MqttClient *client); static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int queued = 0; int inMessagesCount = 0; int outMessagesCount = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (packet->state == MessageStateQueued) + if (msg->state == MqttMessageStateQueued) { ++queued; } @@ -6065,7 +5578,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) ++outMessagesCount; } - TAILQ_FOREACH(packet, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { ++inMessagesCount; } @@ -6093,6 +5606,8 @@ MqttClient *MqttClientNew(const char *clientId) client->maxQueued = 0; client->maxInflight = 20; + client->state = MqttClientStateDisconnected; + TAILQ_INIT(&client->outMessages); TAILQ_INIT(&client->inMessages); SIMPLEQ_INIT(&client->sendQueue); @@ -6108,6 +5623,8 @@ void MqttClientFree(MqttClient *client) bdestroy(client->willTopic); bdestroy(client->willMessage); bdestroy(client->host); + bdestroy(client->userName); + bdestroy(client->password); if (client->stream.sock != -1) { @@ -6163,15 +5680,54 @@ void MqttClientSetOnPublish(MqttClient *client, client->onPublish = cb; } +static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); +static const char MqttProtocolLevel = 0x04; + +static unsigned char MqttClientConnectFlags(MqttClient *client) +{ + unsigned char connectFlags = 0; + + if (client->cleanSession) + { + connectFlags |= 0x02; + } + + if (client->willTopic) + { + connectFlags |= 0x04; + connectFlags |= (client->willQos & 3) << 3; + connectFlags |= (client->willRetain & 1) << 5; + } + + if (client->userName) + { + connectFlags |= 0x80; + if (client->password) + { + connectFlags |= 0x40; + } + } + + return connectFlags; +} + int MqttClientConnect(MqttClient *client, const char *host, short port, int keepAlive, int cleanSession) { int sock; - MqttPacketConnect *packet; + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(host != NULL); + if (client->state != MqttClientStateDisconnected) + { + LOG_ERROR("client must be disconnected to connect"); + return -1; + } + if (client->host) bassigncstr(client->host, host); else @@ -6193,10 +5749,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, LOG_DEBUG("connecting"); - if ((sock = SocketConnect(host, port)) == -1) + if ((sock = SocketConnect(host, port, 1)) == -1) { - LOG_ERROR("SocketConnect failed!"); - return -1; + if (SocketErrno != SOCKET_EINPROGRESS) + { + LOG_ERROR("SocketConnect failed!"); + return -1; + } } if (SocketStreamOpen(&client->stream, sock) == -1) @@ -6204,32 +5763,39 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, return -1; } - packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect); + packet = MqttPacketNew(MqttPacketTypeConnect); if (!packet) return -1; - if (client->cleanSession) - { - packet->connectFlags |= 0x02; - } - - packet->keepAlive = client->keepAlive; + StringStreamInit(&ss); - packet->clientId = bstrcpy(client->clientId); + StreamWriteMqttString(&MqttProtocolId, pss); + StreamWriteByte(MqttProtocolLevel, pss); + StreamWriteByte(MqttClientConnectFlags(client), pss); + StreamWriteUint16Be(client->keepAlive, pss); + StreamWriteMqttString(client->clientId, pss); if (client->willTopic) { - packet->connectFlags |= 0x04; - - packet->willTopic = bstrcpy(client->willTopic); - packet->willMessage = bstrcpy(client->willMessage); + StreamWriteMqttString(client->willTopic, pss); + StreamWriteMqttString(client->willMessage, pss); + } - packet->connectFlags |= (client->willQos & 3) << 3; - packet->connectFlags |= (client->willRetain & 1) << 5; + if (client->userName) + { + StreamWriteMqttString(client->userName, pss); + if(client->password) + { + StreamWriteMqttString(client->password, pss); + } } - MqttClientQueuePacket(client, &packet->base); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + client->state = MqttClientStateConnecting; return 0; } @@ -6242,13 +5808,14 @@ int MqttClientDisconnect(MqttClient *client) int MqttClientIsConnected(MqttClient *client) { - return client->stream.sock != -1; + return client->stream.sock != -1 && + client->state == MqttClientStateConnected; } int MqttClientRunOnce(MqttClient *client, int timeout) { int rv; - int events; + int events = 0; assert(client != NULL); @@ -6258,19 +5825,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout) return -1; } - events = EV_READ; + if (client->state == MqttClientStateConnected) + { + events = EV_READ; - /* Handle outMessages and inMessages, moving queued messages to sendQueue - if there are less than maxInflight number of messages in flight */ - MqttClientProcessMessageQueue(client); + /* Handle outMessages and inMessages, moving queued messages to sendQueue + if there are less than maxInflight number of messages in flight */ + MqttClientProcessMessageQueue(client); - if (SIMPLEQ_EMPTY(&client->sendQueue)) + if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + LOG_DEBUG("nothing to write"); + } + else + { + events |= EV_WRITE; + } + } + else if (client->state == MqttClientStateConnecting) { - LOG_DEBUG("nothing to write"); + events = EV_WRITE; } else { - events |= EV_WRITE; + LOG_ERROR("not connected"); + return -1; } LOG_DEBUG("selecting"); @@ -6278,6 +5857,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (timeout < 0) { timeout = client->keepAlive * 1000; + if (timeout == 0) + { + timeout = 30 * 1000; + } + } + else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0) + { + timeout = client->keepAlive * 1000; } rv = SocketSelect(client->stream.sock, &events, timeout); @@ -6293,22 +5880,26 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (events & EV_WRITE) { - MqttPacket *packet; - LOG_DEBUG("socket writable"); - packet = SIMPLEQ_FIRST(&client->sendQueue); - - if (packet) + if (client->state == MqttClientStateConnecting) { - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (MqttClientSendPacket(client, packet) == -1) + int sockError; + SocketGetError(client->stream.sock, &sockError); + LOG_DEBUG("sockError: %d", sockError); + if (sockError == 0) { - LOG_ERROR("MqttClientSendPacket failed"); - client->stopped = 1; + LOG_DEBUG("connected!"); + client->state = MqttClientStateConnected; + return 0; } } + + if (MqttClientSendPacket(client) == -1) + { + LOG_ERROR("MqttClientSendPacket failed"); + client->stopped = 1; + } } if (events & EV_READ) @@ -6372,10 +5963,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, } int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count) + int *qos, size_t count) { - MqttPacketSubscribe *packet = NULL; + MqttPacket *packet = NULL; size_t i; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(topicFilters != NULL); @@ -6383,68 +5976,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, assert(qos != NULL); assert(count > 0); - packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( - MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe, + MqttClientNextPacketId(client)); if (!packet) return -1; - packet->topicFilters = bstrListCreate(); - bstrListAllocMin(packet->topicFilters, count); + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(packet->id, pss); - packet->qos = (int *) malloc(sizeof(int) * count); + LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id); for (i = 0; i < count; ++i) { - packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]); - ++packet->topicFilters->qty; + struct tagbstring filter; + btfromcstr(filter, topicFilters[i]); + StreamWriteMqttString(&filter, pss); + StreamWriteByte(qos[i] & 3, pss); } - memcpy(packet->qos, qos, sizeof(int) * count); + packet->payload = ss.buffer; - MqttClientQueuePacket(client, (MqttPacket *) packet); - - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter) { - MqttPacketUnsubscribe *packet = NULL; + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + struct tagbstring filter; assert(client != NULL); assert(topicFilter != NULL); - packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew( - MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe, + MqttClientNextPacketId(client)); + + if (!packet) + return -1; - packet->topicFilter = bfromcstr(topicFilter); + packet->flags = 0x02; - MqttClientQueuePacket(client, (MqttPacket *) packet); + StringStreamInit(&ss); - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + StreamWriteUint16Be(packet->id, pss); - return MqttPacketId(packet); + btfromcstr(filter, topicFilter); + + StreamWriteMqttString(&filter, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return packet->id; } static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int count = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { ++count; } return count; } +static MqttPacket *PublishToPacket(MqttMessage *msg) +{ + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + + if (msg->qos > 0) + { + packet = MqttPacketWithIdNew(MqttPacketTypePublish, + msg->id); + } + else + { + packet = MqttPacketNew(MqttPacketTypePublish); + } + + if (!packet) + return NULL; + + packet->message = msg; + + StringStreamInit(&ss); + + StreamWriteMqttString(msg->topic, pss); + + if (msg->qos > 0) + { + StreamWriteUint16Be(msg->id, pss); + } + + StreamWrite(bdata(msg->payload), blength(msg->payload), pss); + + packet->payload = ss.buffer; + packet->flags = (msg->qos & 3) << 1; + packet->flags |= msg->retain & 1; + + return packet; +} + int MqttClientPublish(MqttClient *client, int qos, int retain, const char *topic, const void *data, size_t size) { - MqttPacketPublish *packet; - - assert(client != NULL); + MqttMessage *message; /* first check if the queue is already full */ if (qos > 0 && client->maxQueued > 0 && @@ -6453,55 +6100,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain, return -1; } - if (qos > 0) + message = calloc(1, sizeof(*message)); + if (!message) { - packet = (MqttPacketPublish *) MqttPacketWithIdNew( - MqttPacketTypePublish, MqttClientNextPacketId(client)); + return -1; } - else + + message->state = MqttMessageStateQueued; + message->qos = qos; + message->retain = retain; + message->dup = 0; + message->timestamp = MqttGetCurrentTime(); + + if (qos == 0) { - packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish); - } + /* Copy payload and topic directly from user buffers as we don't need + to keep the message data around after this function. */ + MqttPacket *packet; + struct tagbstring bttopic, btpayload; - if (!packet) - return -1; + btfromcstr(bttopic, topic); + message->topic = &bttopic; - packet->qos = qos; - packet->retain = retain; - packet->topicName = bfromcstr(topic); - packet->message = blk2bstr(data, size); + btfromblk(btpayload, data, size); + message->payload = &btpayload; - if (qos > 0) - { - /* check how many messages there are coming in and going out currently - that are not yet done */ - if (client->maxInflight == 0 || - MqttClientInflightMessageCount(client) < client->maxInflight) - { - LOG_DEBUG("setting message (%d) state to MessageStateSend", - MqttPacketId(packet)); - packet->base.state = MessageStateSend; - } - else - { - LOG_DEBUG("setting message (%d) state to MessageStateQueued", - MqttPacketId(packet)); - packet->base.state = MessageStateQueued; - } + packet = PublishToPacket(message); + + message->topic = NULL; + message->payload = NULL; - /* add the message to the outMessages queue to wait for processing */ - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, - messages); + MqttClientQueuePacket(client, packet); + + MqttMessageFree(message); + + return 0; } else { - MqttClientQueuePacket(client, (MqttPacket *) packet); - } + /* Duplicate the user buffers as we need the data to be available + longer. */ + message->topic = bfromcstr(topic); + message->payload = blk2bstr(data, size); - if (qos > 0) - return MqttPacketId(packet); + message->id = MqttClientNextPacketId(client); - return 0; + TAILQ_INSERT_TAIL(&client->outMessages, message, chain); + + return message->id; + } } int MqttClientPublishCString(MqttClient *client, int qos, int retain, @@ -6547,10 +6194,54 @@ int MqttClientSetWill(MqttClient *client, const char *topic, const void *msg, return 0; } +int MqttClientSetAuth(MqttClient *client, const char *userName, + const char *password) +{ + assert(client != NULL); + + if (client->state == MqttClientStateConnecting) + { + LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect"); + return -1; + } + + if (userName) + { + if (client->userName) + bassigncstr(client->userName, userName); + else + client->userName = bfromcstr(userName); + + if (password) + { + if (client->password) + bassigncstr(client->password, password); + else + client->password = bfromcstr(password); + } + else + { + bdestroy(client->password); + client->password = NULL; + } + } + else + { + bdestroy(client->userName); + client->userName = NULL; + + bdestroy(client->password); + client->password = NULL; + } + + return 0; +} + static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet) { assert(client != NULL); LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type)); + packet->state = MqttPacketStateWriteType; SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue); } @@ -6563,128 +6254,363 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type) return 0; } -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet) +static int MqttClientSendPacket(MqttClient *client) { - if (MqttPacketSerialize(packet, &client->stream.base) == -1) - return -1; + MqttPacket *packet; - packet->sentAt = MqttGetCurrentTime(); - client->lastPacketSentTime = packet->sentAt; + packet = SIMPLEQ_FIRST(&client->sendQueue); - if (packet->type == MqttPacketTypeDisconnect) + if (!packet) { - client->stopped = 1; + LOG_WARNING("MqttClientSendPacket called with no queued packets"); + return 0; } - /* If the packet is not on any message list, it can be removed after - sending. */ - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) + while (packet != NULL) { - LOG_DEBUG("freeing packet %s after sending", - MqttPacketName(MqttPacketType(packet))); - MqttPacketFree(packet); + switch (packet->state) + { + case MqttPacketStateWriteType: + { + unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) | + (packet->flags & 0x0F); + + if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWriteRemainingLength; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWriteRemainingLength: + { + if (StreamWriteRemainingLength(&packet->remainingLength, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWritePayload; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWritePayload: + { + if (packet->payload) + { + int64_t offset = blength(packet->payload) - packet->remainingLength; + int64_t nwritten = 0; + int towrite = 16*1024; + + if (packet->remainingLength < 16*1024) + towrite = packet->remainingLength; + + nwritten = StreamWrite(bdataofs(packet->payload, offset), + towrite, + &client->stream.base); + + if (nwritten == -1) + { + if (SocketWouldBlock(SocketErrno)) + { + return 0; + } + return -1; + } + + packet->remainingLength -= nwritten; + + LOG_DEBUG("nwritten:%d", (int) nwritten); + } + + if (packet->remainingLength == 0) + { + LOG_DEBUG("packet payload sent"); + packet->state = MqttPacketStateWriteComplete; + } + + break; + } + + case MqttPacketStateWriteComplete: + { + client->lastPacketSentTime = MqttGetCurrentTime(); + + if (packet->type == MqttPacketTypeDisconnect) + { + client->stopped = 1; + client->state = MqttClientStateDisconnected; + } + + LOG_DEBUG("sent %s", MqttPacketName(packet->type)); + + if (packet->type == MqttPacketTypePublish && packet->message) + { + MqttMessage *msg = packet->message; + + if (msg->qos == 1) + { + msg->state = MqttMessageStateWaitPubAck; + } + else if (msg->qos == 2) + { + msg->state = MqttMessageStateWaitPubRec; + } + } + + if (packet->message) + { + packet->message->timestamp = client->lastPacketSentTime; + } + + SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); + + MqttPacketFree(packet); + + packet = SIMPLEQ_FIRST(&client->sendQueue); + + break; + } + } } return 0; } -static void MqttClientHandleConnAck(MqttClient *client, - MqttPacketConnAck *packet) +static int MqttClientHandleConnAck(MqttClient *client) { - client->sessionPresent = packet->connAckFlags & 1; + StringStream ss; + Stream *pss = (Stream *) &ss; + unsigned char flags; + unsigned char rc; + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadByte(&flags, pss); + + StreamReadByte(&rc, pss); + + client->sessionPresent = flags & 1; LOG_DEBUG("sessionPresent:%d", client->sessionPresent); if (client->onConnect) { - LOG_DEBUG("calling onConnect rc:%d", packet->returnCode); - client->onConnect(client, packet->returnCode, client->sessionPresent); + LOG_DEBUG("calling onConnect rc:%d", rc); + client->onConnect(client, rc, client->sessionPresent); } + + return 0; } -static void MqttClientHandlePingResp(MqttClient *client) +static int MqttClientHandlePingResp(MqttClient *client) { LOG_DEBUG("got ping response"); client->pingSent = 0; + return 0; } -static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) +static int MqttClientHandleSubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + int *qos; + StringStream ss; + Stream *pss = (Stream *) &ss; + int count; + int i; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + LOG_DEBUG("received SUBACK with id:%d", (int) id); + + count = blength(client->inPacket.payload) - StreamTell(pss); + + if (count <= 0) { - if (MqttPacketType(sub) == MqttPacketTypeSubscribe && - MqttPacketId(sub) == MqttPacketId(packet)) - { - break; - } + LOG_ERROR("number of return codes invalid"); + return -1; } - if (!sub) + qos = malloc(count * sizeof(int)); + + for (i = 0; i < count; ++i) { - LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + unsigned char byte; + StreamReadByte(&byte, pss); + qos[i] = byte; } - else + + if (client->onSubscribe) { - if (client->onSubscribe) - { - MqttPacketSubscribe *sub2; - int i; + client->onSubscribe(client, id, qos, count); + } - sub2 = (MqttPacketSubscribe *) sub; + free(qos); - for (i = 0; i < sub2->topicFilters->qty; ++i) - { - const char *filter = bdata(sub2->topicFilters->entry[i]); - int rc = packet->returnCode[i]; + return 0; +} - LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d", - MqttPacketId(packet), filter, rc); +static int MqttClientSendPubAck(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; - client->onSubscribe(client, MqttPacketId(packet), filter, rc); - } - } + packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id); - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - } + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; } -static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet) +static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg) { + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id); + + if (!packet) + return -1; + + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubComp(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientHandlePublish(MqttClient *client) +{ + MqttMessage *msg; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; + MqttPacket *packet; + int qos; + int retain; + bstring topic; + void *payload; + int payloadSize; + + /* We are paused - do nothing */ if (client->paused) - return; + return 0; - if (MqttPacketPublishQos(packet) == 2) + packet = &client->inPacket; + + qos = (packet->flags >> 1) & 3; + retain = packet->flags & 1; + + StringStreamInitFromBstring(&ss, packet->payload); + + StreamReadMqttString(&topic, pss); + + if (qos > 0) + { + StreamReadUint16Be(&id, pss); + } + + payload = bdataofs(ss.buffer, ss.pos); + payloadSize = blength(ss.buffer) - ss.pos; + + if (qos == 2) { /* Check if we have sent a PUBREC previously with the same id. If we have, we have to resend the PUBREC. We must not call the onMessage callback again. */ - MqttPacket *pubRec; - - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (pubRec) + if (msg) { - LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet)); - MqttClientQueuePacket(client, pubRec); - return; + LOG_DEBUG("resending PUBREC id:%u", msg->id); + MqttClientSendPubRec(client, msg); + bdestroy(topic); + return 0; } } @@ -6692,268 +6618,395 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe { LOG_DEBUG("calling onMessage"); client->onMessage(client, - bdata(packet->topicName), - bdata(packet->message), - blength(packet->message), - packet->qos, - packet->retain); + bdata(topic), + payload, + payloadSize, + qos, + retain); } - if (MqttPacketPublishQos(packet) > 0) + bdestroy(topic); + + if (qos == 1) + { + MqttClientSendPubAck(client, id); + } + else if (qos == 2) { - int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck : - MqttPacketTypePubRec; + msg = calloc(1, sizeof(*msg)); - MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet)); + msg->state = MqttMessageStateWaitPubRel; + msg->id = id; + msg->qos = qos; - if (MqttPacketPublishQos(packet) == 2) - { - /* append to inMessages as we need a reply to this response */ - TAILQ_INSERT_TAIL(&client->inMessages, resp, messages); - } + TAILQ_INSERT_TAIL(&client->inMessages, msg, chain); - MqttClientQueuePacket(client, resp); + MqttClientSendPubRec(client, msg); } + + return 0; } -static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubAck(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); - TAILQ_FOREACH(pub, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubAck) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); - if (client->onPublish) - { - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + if (client->onPublish) + { + client->onPublish(client, msg->id); } + + MqttMessageFree(msg); + + return 0; } -static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRec(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but + they haven't received it. */ + if (msg->id == id && + (msg->state == MqttMessageStateWaitPubRec || + msg->state == MqttMessageStateWaitPubComp)) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubRel; - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); + msg->state = MqttMessageStateWaitPubComp; - pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet)); - pubRel->state = MessageStateSend; + bdestroy(msg->payload); + msg->payload = NULL; - TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages); - } + bdestroy(msg->topic); + msg->topic = NULL; + + if (MqttClientSendPubRel(client, msg) == -1) + return -1; + + return 0; } -static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRel(MqttClient *client) { - MqttPacket *pubRec; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (!pubRec) + if (!msg) { - MqttPacket *pubComp; - - TAILQ_FOREACH(pubComp, &client->inMessages, messages) - { - if (MqttPacketId(pubComp) == MqttPacketId(packet) && - MqttPacketType(pubComp) == MqttPacketTypePubComp) - { - break; - } - } - - if (pubComp) - { - MqttClientQueuePacket(client, pubComp); - } - else - { - LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; - } + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubComp; - - TAILQ_REMOVE(&client->inMessages, pubRec, messages); - MqttPacketFree(pubRec); - pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp, - MqttPacketId(packet)); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); - TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages); + if (MqttClientSendPubComp(client, id) == -1) + return -1; - MqttClientQueuePacket(client, pubComp); - } + return 0; } -static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubComp(MqttClient *client) { - MqttPacket *pubRel; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); + + StringStreamInitFromBstring(&ss, client->inPacket.payload); - TAILQ_FOREACH(pubRel, &client->outMessages, messages) + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pubRel) == MqttPacketId(packet) && - MqttPacketType(pubRel) == MqttPacketTypePubRel) + if (msg->id == id && msg->state == MqttMessageStateWaitPubComp) { break; } } - if (!pubRel) + if (!msg) { - LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_WARNING("no message found with id %d", (int) id); + return 0; } - else - { - TAILQ_REMOVE(&client->outMessages, pubRel, messages); - MqttPacketFree(pubRel); - if (client->onPublish) - { - LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet)); - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + MqttMessageFree(msg); + + if (client->onPublish) + { + LOG_DEBUG("calling onPublish id:%d", id); + client->onPublish(client, id); } + + return 0; } -static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandleUnsubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) - { - if (MqttPacketId(sub) == MqttPacketId(packet) && - MqttPacketType(sub) == MqttPacketTypeUnsubscribe) - { - break; - } - } + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); - if (!sub) + if (client->onUnsubscribe) { - LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + client->onUnsubscribe(client, id); } - else - { - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - if (client->onUnsubscribe) - { - LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet)); - client->onUnsubscribe(client, MqttPacketId(packet)); - } - } + return 0; } -static int MqttClientRecvPacket(MqttClient *client) +static int MqttClientHandlePacket(MqttClient *client) { - MqttPacket *packet = NULL; - - if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1) - return -1; - - LOG_DEBUG("received packet %s", MqttPacketName(packet->type)); + int rc; - switch (MqttPacketType(packet)) + switch (client->inPacket.type) { case MqttPacketTypeConnAck: - MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet); + rc = MqttClientHandleConnAck(client); break; case MqttPacketTypePingResp: - MqttClientHandlePingResp(client); + rc = MqttClientHandlePingResp(client); break; case MqttPacketTypeSubAck: - MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet); + rc = MqttClientHandleSubAck(client); break; - case MqttPacketTypePublish: - MqttClientHandlePublish(client, (MqttPacketPublish *) packet); + case MqttPacketTypeUnsubAck: + rc = MqttClientHandleUnsubAck(client); break; case MqttPacketTypePubAck: - MqttClientHandlePubAck(client, packet); + rc = MqttClientHandlePubAck(client); break; case MqttPacketTypePubRec: - MqttClientHandlePubRec(client, packet); + rc = MqttClientHandlePubRec(client); break; - case MqttPacketTypePubRel: - MqttClientHandlePubRel(client, packet); + case MqttPacketTypePubComp: + rc = MqttClientHandlePubComp(client); break; - case MqttPacketTypePubComp: - MqttClientHandlePubComp(client, packet); + case MqttPacketTypePubRel: + rc = MqttClientHandlePubRel(client); break; - case MqttPacketTypeUnsubAck: - MqttClientHandleUnsubAck(client, packet); + case MqttPacketTypePublish: + rc = MqttClientHandlePublish(client); break; default: - LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet)); + LOG_ERROR("packet not handled yet"); + rc = -1; break; } - MqttPacketFree(packet); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + + client->inPacket.state = MqttPacketStateReadType; + + return rc; +} + +static int MqttClientRecvPacket(MqttClient *client) +{ + while (1) + { + switch (client->inPacket.state) + { + case MqttPacketStateReadType: + { + unsigned char typeAndFlags; + + if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet type"); + return -1; + } + + client->inPacket.type = typeAndFlags >> 4; + client->inPacket.flags = typeAndFlags & 0x0F; + + if (client->inPacket.type < MqttPacketTypeConnect || + client->inPacket.type > MqttPacketTypeDisconnect) + { + LOG_ERROR("unknown packet type: %d", client->inPacket.type); + return -1; + } + + client->inPacket.state = MqttPacketStateReadRemainingLength; + client->inPacket.remainingLength = 0; + client->inPacket.remainingLengthMul = 1; + client->inPacket.payload = NULL; + + break; + } + + case MqttPacketStateReadRemainingLength: + { + if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->inPacket.remainingLengthMul, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed to read remaining length"); + return -1; + } + + LOG_DEBUG("remainingLength:%lu", + client->inPacket.remainingLength); + + client->inPacket.state = MqttPacketStateReadPayload; + + break; + } + + case MqttPacketStateReadPayload: + { + if (client->inPacket.remainingLength > 0) + { + int64_t nread, offset, toread; + + if (client->inPacket.payload == NULL) + { + unsigned char *data; + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + data = client->inPacket.payload->data; + data[client->inPacket.remainingLength] = '\0'; + } + + offset = blength(client->inPacket.payload); + + toread = 16*1024; + + if (client->inPacket.remainingLength < (size_t) toread) + toread = client->inPacket.remainingLength; + + nread = StreamRead(bdataofs(client->inPacket.payload, + offset), + toread, + &client->stream.base); + + if (nread == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet payload"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + else if (nread == 0) + { + LOG_ERROR("socket disconnected"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + + client->inPacket.remainingLength -= nread; + client->inPacket.payload->slen += nread; + + LOG_DEBUG("nread:%d", (int) nread); + } + + if (client->inPacket.remainingLength == 0) + { + client->inPacket.state = MqttPacketStateReadComplete; + } + break; + } + + case MqttPacketStateReadComplete: + { + int type = client->inPacket.type; + LOG_DEBUG("received %s", MqttPacketName(type)); + return MqttClientHandlePacket(client); + } + } + } return 0; } @@ -6968,101 +7021,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client) return id; } -static int64_t MqttPacketTimeSinceSent(MqttPacket *packet) +static int64_t MqttMessageTimeSinceSent(MqttMessage *msg) { int64_t now = MqttGetCurrentTime(); - return now - packet->sentAt; + return now - msg->timestamp; } -static void MqttClientProcessInMessages(MqttClient *client) +static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg) { - MqttPacket *packet, *next; - - LOG_DEBUG("processing inMessages"); - - TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next) + if (msg->timestamp > 0 && + MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000) { - LOG_DEBUG("packet type:%s id:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - - if (MqttPacketType(packet) == MqttPacketTypePubComp) - { - int64_t elapsed = MqttPacketTimeSinceSent(packet); - if (packet->sentAt > 0 && - elapsed >= client->retryTimeout*1000) - { - LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64, - MqttPacketId(packet), elapsed); - - TAILQ_REMOVE(&client->inMessages, packet, messages); - - MqttPacketFree(packet); - } - } + return 1; } + + return 0; } -static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet) +static void MqttClientProcessInMessages(MqttClient *client) { - if (packet->sentAt > 0 && - MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000) + MqttMessage *msg, *next; + + TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next) { - return 1; - } + switch (msg->state) + { + case MqttMessageStateWaitPubRel: + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRec(client, msg); + } + break; - return 0; + default: + break; + } + } } static void MqttClientProcessOutMessages(MqttClient *client) { - MqttPacket *packet, *next; + MqttMessage *msg, *next; + MqttPacket *packet; int inflight = MqttClientInflightMessageCount(client); - LOG_DEBUG("processing outMessages inflight:%d", inflight); - - TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next) + TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next) { - LOG_DEBUG("packet type:%s id:%d state:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet), - packet->state); - - switch (packet->state) + switch (msg->state) { - case MessageStateQueued: + case MqttMessageStateQueued: + { if (inflight >= client->maxInflight) { - LOG_DEBUG("cannot dequeue %s/%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - break; - } - else - { - /* If there's less than maxInflight messages currently - inflight, we can dequeue some messages by falling - through to MessageStateSend. */ - LOG_DEBUG("dequeuing %s (%d)", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - ++inflight; + continue; } - - case MessageStateSend: - packet->state = MessageStateSent; + /* State change from MqttMessageStatePublish happens after + the packet has been sent (in MqttClientSendPacket). */ + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); MqttClientQueuePacket(client, packet); + ++inflight; break; + } - case MessageStateSent: - if (MqttPacketShouldResend(client, packet)) + case MqttMessageStateWaitPubAck: + case MqttMessageStateWaitPubRec: + { + if (MqttMessageShouldResend(client, msg)) { - packet->state = MessageStateSend; + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); + MqttClientQueuePacket(client, packet); } break; + } - default: + case MqttMessageStateWaitPubComp: + { + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRel(client, msg); + } break; + } } } } @@ -7078,30 +7119,22 @@ static void MqttClientClearQueues(MqttClient *client) while (!SIMPLEQ_EMPTY(&client->sendQueue)) { MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue); - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) - { - MqttPacketFree(packet); - } + MqttPacketFree(packet); } while (!TAILQ_EMPTY(&client->outMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->outMessages); - TAILQ_REMOVE(&client->outMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->outMessages); + TAILQ_REMOVE(&client->outMessages, msg, chain); + MqttMessageFree(msg); } while (!TAILQ_EMPTY(&client->inMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->inMessages); - TAILQ_REMOVE(&client->inMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->inMessages); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); } } diff --git a/amalgamation/mqtt.h b/amalgamation/mqtt.h index ad84aaf..840026e 100644 --- a/amalgamation/mqtt.h +++ b/amalgamation/mqtt.h @@ -33,8 +33,8 @@ typedef void (*MqttClientOnConnectCallback)(MqttClient *client, typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client, int id, - const char *topicFilter, - MqttSubscriptionStatus status); + int *qos, + int count); typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id); @@ -84,7 +84,7 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, int qos); int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count); + int *qos, size_t count); int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter); @@ -103,6 +103,9 @@ void MqttClientSetMaxQueuedMessages(MqttClient *client, int max); int MqttClientSetWill(MqttClient *client, const char *topic, const void *msg, size_t size, int qos, int retain); +int MqttClientSetAuth(MqttClient *client, const char *username, + const char *password); + #if defined(__cplusplus) } #endif |
