diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:17:16 +0200 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:17:16 +0200 |
| commit | 03f7cae60919a04ff0ebc87baf3b51b9bbb1776f (patch) | |
| tree | 3d0306c4b5f5ddef77e9bcd0ec8cadf3013ba13d | |
| parent | d97c786dbd30b4349d22b41c657f69a335f3d77a (diff) | |
| download | mqtt-03f7cae60919a04ff0ebc87baf3b51b9bbb1776f.tar.gz mqtt-03f7cae60919a04ff0ebc87baf3b51b9bbb1776f.zip | |
Modify the code to use nonblocking sockets
| -rw-r--r-- | src/client.c | 183 | ||||
| -rw-r--r-- | src/packet.h | 2 | ||||
| -rw-r--r-- | src/socket.c | 73 | ||||
| -rw-r--r-- | src/socket.h | 26 | ||||
| -rw-r--r-- | src/stream_mqtt.c | 30 | ||||
| -rw-r--r-- | src/stream_mqtt.h | 5 |
6 files changed, 257 insertions, 62 deletions
diff --git a/src/client.c b/src/client.c index 704a53e..e303fe9 100644 --- a/src/client.c +++ b/src/client.c @@ -26,6 +26,15 @@ #error define PRId64 for your platform #endif +typedef enum MqttClientState MqttClientState; + +enum MqttClientState +{ + MqttClientStateDisconnected, + MqttClientStateConnecting, + MqttClientStateConnected, +}; + struct MqttClient { SocketStream stream; @@ -80,6 +89,7 @@ struct MqttClient bstring password; /* The packet we are receiving */ MqttPacket inPacket; + MqttClientState state; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); @@ -135,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId) client->maxQueued = 0; client->maxInflight = 20; + client->state = MqttClientStateDisconnected; + TAILQ_INIT(&client->outMessages); TAILQ_INIT(&client->inMessages); SIMPLEQ_INIT(&client->sendQueue); @@ -249,6 +261,12 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, assert(client != NULL); assert(host != NULL); + if (client->state != MqttClientStateDisconnected) + { + LOG_ERROR("client must be disconnected to connect"); + return -1; + } + if (client->host) bassigncstr(client->host, host); else @@ -270,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, LOG_DEBUG("connecting"); - if ((sock = SocketConnect(host, port)) == -1) + if ((sock = SocketConnect(host, port, 1)) == -1) { - LOG_ERROR("SocketConnect failed!"); - return -1; + if (SocketErrno != SOCKET_EINPROGRESS) + { + LOG_ERROR("SocketConnect failed!"); + return -1; + } } if (SocketStreamOpen(&client->stream, sock) == -1) @@ -313,6 +334,8 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, MqttClientQueuePacket(client, packet); + client->state = MqttClientStateConnecting; + return 0; } @@ -324,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client) int MqttClientIsConnected(MqttClient *client) { - return client->stream.sock != -1; + return client->stream.sock != -1 && + client->state == MqttClientStateConnected; } int MqttClientRunOnce(MqttClient *client, int timeout) { int rv; - int events; + int events = 0; assert(client != NULL); @@ -340,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout) return -1; } - events = EV_READ; + if (client->state == MqttClientStateConnected) + { + events = EV_READ; - /* Handle outMessages and inMessages, moving queued messages to sendQueue - if there are less than maxInflight number of messages in flight */ - MqttClientProcessMessageQueue(client); + /* Handle outMessages and inMessages, moving queued messages to sendQueue + if there are less than maxInflight number of messages in flight */ + MqttClientProcessMessageQueue(client); - if (SIMPLEQ_EMPTY(&client->sendQueue)) + if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + LOG_DEBUG("nothing to write"); + } + else + { + events |= EV_WRITE; + } + } + else if (client->state == MqttClientStateConnecting) { - LOG_DEBUG("nothing to write"); + events = EV_WRITE; } else { - events |= EV_WRITE; + LOG_ERROR("not connected"); + return -1; } LOG_DEBUG("selecting"); @@ -385,6 +421,19 @@ int MqttClientRunOnce(MqttClient *client, int timeout) { LOG_DEBUG("socket writable"); + if (client->state == MqttClientStateConnecting) + { + int sockError; + SocketGetError(client->stream.sock, &sockError); + LOG_DEBUG("sockError: %d", sockError); + if (sockError == 0) + { + LOG_DEBUG("connected!"); + client->state = MqttClientStateConnected; + return 0; + } + } + if (MqttClientSendPacket(client) == -1) { LOG_ERROR("MqttClientSendPacket failed"); @@ -689,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName, { assert(client != NULL); - if (MqttClientIsConnected(client)) + if (client->state == MqttClientStateConnecting) { LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect"); return -1; @@ -767,23 +816,29 @@ static int MqttClientSendPacket(MqttClient *client) if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) { + if (SocketWouldBlock(SocketErrno)) + return 0; return -1; } packet->state = MqttPacketStateWriteRemainingLength; + packet->remainingLength = blength(packet->payload); break; } case MqttPacketStateWriteRemainingLength: { - if (StreamWriteRemainingLength(blength(packet->payload), + if (StreamWriteRemainingLength(&packet->remainingLength, &client->stream.base) == -1) { + if (SocketWouldBlock(SocketErrno)) + return 0; return -1; } packet->state = MqttPacketStateWritePayload; + packet->remainingLength = blength(packet->payload); break; } @@ -792,15 +847,36 @@ static int MqttClientSendPacket(MqttClient *client) { if (packet->payload) { - if (StreamWrite(bdata(packet->payload), - blength(packet->payload), - &client->stream.base) == -1) + int64_t offset = blength(packet->payload) - packet->remainingLength; + int64_t nwritten = 0; + int towrite = 16*1024; + + if (packet->remainingLength < 16*1024) + towrite = packet->remainingLength; + + nwritten = StreamWrite(bdataofs(packet->payload, offset), + towrite, + &client->stream.base); + + if (nwritten == -1) { + if (SocketWouldBlock(SocketErrno)) + { + return 0; + } return -1; } + + packet->remainingLength -= nwritten; + + LOG_DEBUG("nwritten:%d", (int) nwritten); } - packet->state = MqttPacketStateWriteComplete; + if (packet->remainingLength == 0) + { + LOG_DEBUG("packet payload sent"); + packet->state = MqttPacketStateWriteComplete; + } break; } @@ -812,6 +888,7 @@ static int MqttClientSendPacket(MqttClient *client) if (packet->type == MqttPacketTypeDisconnect) { client->stopped = 1; + client->state = MqttClientStateDisconnected; } LOG_DEBUG("sent %s", MqttPacketName(packet->type)); @@ -1353,11 +1430,12 @@ static int MqttClientRecvPacket(MqttClient *client) case MqttPacketStateReadType: { unsigned char typeAndFlags; - int rc; - if ((rc = StreamReadByte(&typeAndFlags, &client->stream.base)) != 1) + if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1) { - LOG_ERROR("failed reading packet type: %d", rc); + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet type"); return -1; } @@ -1372,6 +1450,9 @@ static int MqttClientRecvPacket(MqttClient *client) } client->inPacket.state = MqttPacketStateReadRemainingLength; + client->inPacket.remainingLength = 0; + client->inPacket.remainingLengthMul = 1; + client->inPacket.payload = NULL; break; } @@ -1379,12 +1460,20 @@ static int MqttClientRecvPacket(MqttClient *client) case MqttPacketStateReadRemainingLength: { if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->inPacket.remainingLengthMul, &client->stream.base) == -1) { + if (SocketWouldBlock(SocketErrno)) + return 0; LOG_ERROR("failed to read remaining length"); return -1; } + + LOG_DEBUG("remainingLength:%lu", + client->inPacket.remainingLength); + client->inPacket.state = MqttPacketStateReadPayload; + break; } @@ -1392,21 +1481,57 @@ static int MqttClientRecvPacket(MqttClient *client) { if (client->inPacket.remainingLength > 0) { - client->inPacket.payload = bfromcstr(""); - ballocmin(client->inPacket.payload, - client->inPacket.remainingLength+1); - if (StreamRead(bdata(client->inPacket.payload), - client->inPacket.remainingLength, - &client->stream.base) == -1) + int64_t nread, offset, toread; + + if (client->inPacket.payload == NULL) { + unsigned char *data; + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + data = client->inPacket.payload->data; + data[client->inPacket.remainingLength] = '\0'; + } + + offset = blength(client->inPacket.payload); + + toread = 16*1024; + + if (client->inPacket.remainingLength < (size_t) toread) + toread = client->inPacket.remainingLength; + + nread = StreamRead(bdataofs(client->inPacket.payload, + offset), + toread, + &client->stream.base); + + if (nread == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; LOG_ERROR("failed reading packet payload"); bdestroy(client->inPacket.payload); client->inPacket.payload = NULL; return -1; } - client->inPacket.payload->slen = client->inPacket.remainingLength; + else if (nread == 0) + { + LOG_ERROR("socket disconnected"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + + client->inPacket.remainingLength -= nread; + client->inPacket.payload->slen += nread; + + LOG_DEBUG("nread:%d", (int) nread); + } + + if (client->inPacket.remainingLength == 0) + { + client->inPacket.state = MqttPacketStateReadComplete; } - client->inPacket.state = MqttPacketStateReadComplete; break; } diff --git a/src/packet.h b/src/packet.h index 36dc81f..a5e2ce7 100644 --- a/src/packet.h +++ b/src/packet.h @@ -53,6 +53,8 @@ struct MqttPacket int state; uint16_t id; size_t remainingLength; + size_t remainingLengthMul; + /* TODO: maybe switch to have a StringStream here? */ bstring payload; struct MqttMessage *message; SIMPLEQ_ENTRY(MqttPacket) sendQueue; diff --git a/src/socket.c b/src/socket.c index 64a7c01..b70f4fb 100644 --- a/src/socket.c +++ b/src/socket.c @@ -6,18 +6,6 @@ #include <assert.h> #if defined(_WIN32) -#include "win32.h" -#else -#include <sys/types.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/select.h> -#include <netdb.h> -#include <unistd.h> -#include <arpa/inet.h> -#endif - -#if defined(_WIN32) static int InitializeWsa() { WSADATA wsa; @@ -33,9 +21,9 @@ static int InitializeWsa() #define close closesocket #endif -int SocketConnect(const char *host, short port) +int SocketConnect(const char *host, short port, int nonblocking) { - struct addrinfo hints, *servinfo, *p = NULL; + struct addrinfo hints, *servinfo = NULL, *p = NULL; int rv; char portstr[6]; int sock; @@ -66,8 +54,16 @@ int SocketConnect(const char *host, short port) continue; } + if (nonblocking) + { + SocketSetNonblocking(sock, 1); + } + if (connect(sock, p->ai_addr, p->ai_addrlen) == -1) { + int err = SocketErrno; + if (err == SOCKET_EINPROGRESS) + break; close(sock); continue; } @@ -75,10 +71,13 @@ int SocketConnect(const char *host, short port) break; } - freeaddrinfo(servinfo); - cleanup: + if (servinfo) + { + freeaddrinfo(servinfo); + } + if (p == NULL) { #if defined(_WIN32) @@ -178,3 +177,45 @@ int SocketSelect(int sock, int *events, int timeout) return rv; } + +void SocketSetNonblocking(int sock, int nb) +{ +#if defined(_WIN32) + unsigned int yes = nb; + ioctlsocket(s, FIONBIO, &yes); +#else + int flags = fcntl(sock, F_GETFL, 0); + if (nb) + flags |= O_NONBLOCK; + else + flags &= ~O_NONBLOCK; + fcntl(sock, F_SETFL, flags); +#endif +} + +int SocketGetOpt(int sock, int level, int name, void *val, int *len) +{ +#if defined(_WIN32) + return getsockopt(sock, level, name, (char *) val, len); +#else + socklen_t _len = *len; + int rc = getsockopt(sock, level, name, val, &_len); + *len = _len; + return rc; +#endif +} + +int SocketGetError(int sock, int *error) +{ + int len = sizeof(*error); + return SocketGetOpt(sock, SOL_SOCKET, SO_ERROR, error, &len); +} + +int SocketWouldBlock(int error) +{ +#if defined(_WIN32) + return error == WSAEWOULDBLOCK; +#else + return error == EWOULDBLOCK || error == EAGAIN; +#endif +} diff --git a/src/socket.h b/src/socket.h index e7b1a80..abc67af 100644 --- a/src/socket.h +++ b/src/socket.h @@ -6,7 +6,25 @@ #include <stdlib.h> #include <stdint.h> -int SocketConnect(const char *host, short port); +#if defined(_WIN32) +#include "win32.h" +#define SocketErrno (WSAGetLastError()) +#define SOCKET_EINPROGRESS (WSAEWOULDBLOCK) +#else +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/select.h> +#include <netdb.h> +#include <unistd.h> +#include <arpa/inet.h> +#include <fcntl.h> +#include <errno.h> +#define SocketErrno (errno) +#define SOCKET_EINPROGRESS (EINPROGRESS) +#endif + +int SocketConnect(const char *host, short port, int nonblocking); int SocketDisconnect(int sock); @@ -24,4 +42,10 @@ int64_t SocketRecv(int sock, void *buf, size_t len, int flags); int64_t SocketSend(int sock, const void *buf, size_t len, int flags); +void SocketSetNonblocking(int sock, int nb); + +int SocketGetError(int sock, int *error); + +int SocketWouldBlock(int error); + #endif diff --git a/src/stream_mqtt.c b/src/stream_mqtt.c index 3864ef3..f2bd9cd 100644 --- a/src/stream_mqtt.c +++ b/src/stream_mqtt.c @@ -42,37 +42,39 @@ int64_t StreamWriteMqttString(const_bstring buf, Stream *stream) return 2 + blength(buf); } -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream) +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream) { - size_t multiplier = 1; unsigned char encodedByte; - *remainingLength = 0; do { if (StreamRead(&encodedByte, 1, stream) != 1) return -1; - *remainingLength += (encodedByte & 127) * multiplier; - if (multiplier > 128*128*128) + *remainingLength += (encodedByte & 127) * (*mul); + if ((*mul) > 128*128*128) return -1; - multiplier *= 128; + (*mul) *= 128; } while ((encodedByte & 128) != 0); + *mul = 0; return 0; } -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream) +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream) { - size_t nbytes = 0; do { - unsigned char encodedByte = remainingLength % 128; - remainingLength /= 128; - if (remainingLength > 0) + size_t tmp = *remainingLength; + unsigned char encodedByte = tmp % 128; + tmp /= 128; + if (tmp > 0) encodedByte |= 128; if (StreamWrite(&encodedByte, 1, stream) != 1) + { return -1; - ++nbytes; + } + *remainingLength = tmp; } - while (remainingLength > 0); - return nbytes; + while (*remainingLength > 0); + return 0; } diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h index a128524..8c8ccb5 100644 --- a/src/stream_mqtt.h +++ b/src/stream_mqtt.h @@ -9,7 +9,8 @@ int64_t StreamReadMqttString(bstring *buf, Stream *stream); int64_t StreamWriteMqttString(const_bstring buf, Stream *stream); -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream); -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream); +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream); +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream); #endif |
