aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOskari Timperi <oskari.timperi@iki.fi>2017-03-18 09:17:16 +0200
committerOskari Timperi <oskari.timperi@iki.fi>2017-03-18 09:17:16 +0200
commit03f7cae60919a04ff0ebc87baf3b51b9bbb1776f (patch)
tree3d0306c4b5f5ddef77e9bcd0ec8cadf3013ba13d
parentd97c786dbd30b4349d22b41c657f69a335f3d77a (diff)
downloadmqtt-03f7cae60919a04ff0ebc87baf3b51b9bbb1776f.tar.gz
mqtt-03f7cae60919a04ff0ebc87baf3b51b9bbb1776f.zip
Modify the code to use nonblocking sockets
-rw-r--r--src/client.c183
-rw-r--r--src/packet.h2
-rw-r--r--src/socket.c73
-rw-r--r--src/socket.h26
-rw-r--r--src/stream_mqtt.c30
-rw-r--r--src/stream_mqtt.h5
6 files changed, 257 insertions, 62 deletions
diff --git a/src/client.c b/src/client.c
index 704a53e..e303fe9 100644
--- a/src/client.c
+++ b/src/client.c
@@ -26,6 +26,15 @@
#error define PRId64 for your platform
#endif
+typedef enum MqttClientState MqttClientState;
+
+enum MqttClientState
+{
+ MqttClientStateDisconnected,
+ MqttClientStateConnecting,
+ MqttClientStateConnected,
+};
+
struct MqttClient
{
SocketStream stream;
@@ -80,6 +89,7 @@ struct MqttClient
bstring password;
/* The packet we are receiving */
MqttPacket inPacket;
+ MqttClientState state;
};
static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet);
@@ -135,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId)
client->maxQueued = 0;
client->maxInflight = 20;
+ client->state = MqttClientStateDisconnected;
+
TAILQ_INIT(&client->outMessages);
TAILQ_INIT(&client->inMessages);
SIMPLEQ_INIT(&client->sendQueue);
@@ -249,6 +261,12 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
assert(client != NULL);
assert(host != NULL);
+ if (client->state != MqttClientStateDisconnected)
+ {
+ LOG_ERROR("client must be disconnected to connect");
+ return -1;
+ }
+
if (client->host)
bassigncstr(client->host, host);
else
@@ -270,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
LOG_DEBUG("connecting");
- if ((sock = SocketConnect(host, port)) == -1)
+ if ((sock = SocketConnect(host, port, 1)) == -1)
{
- LOG_ERROR("SocketConnect failed!");
- return -1;
+ if (SocketErrno != SOCKET_EINPROGRESS)
+ {
+ LOG_ERROR("SocketConnect failed!");
+ return -1;
+ }
}
if (SocketStreamOpen(&client->stream, sock) == -1)
@@ -313,6 +334,8 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
MqttClientQueuePacket(client, packet);
+ client->state = MqttClientStateConnecting;
+
return 0;
}
@@ -324,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client)
int MqttClientIsConnected(MqttClient *client)
{
- return client->stream.sock != -1;
+ return client->stream.sock != -1 &&
+ client->state == MqttClientStateConnected;
}
int MqttClientRunOnce(MqttClient *client, int timeout)
{
int rv;
- int events;
+ int events = 0;
assert(client != NULL);
@@ -340,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
return -1;
}
- events = EV_READ;
+ if (client->state == MqttClientStateConnected)
+ {
+ events = EV_READ;
- /* Handle outMessages and inMessages, moving queued messages to sendQueue
- if there are less than maxInflight number of messages in flight */
- MqttClientProcessMessageQueue(client);
+ /* Handle outMessages and inMessages, moving queued messages to sendQueue
+ if there are less than maxInflight number of messages in flight */
+ MqttClientProcessMessageQueue(client);
- if (SIMPLEQ_EMPTY(&client->sendQueue))
+ if (SIMPLEQ_EMPTY(&client->sendQueue))
+ {
+ LOG_DEBUG("nothing to write");
+ }
+ else
+ {
+ events |= EV_WRITE;
+ }
+ }
+ else if (client->state == MqttClientStateConnecting)
{
- LOG_DEBUG("nothing to write");
+ events = EV_WRITE;
}
else
{
- events |= EV_WRITE;
+ LOG_ERROR("not connected");
+ return -1;
}
LOG_DEBUG("selecting");
@@ -385,6 +421,19 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
{
LOG_DEBUG("socket writable");
+ if (client->state == MqttClientStateConnecting)
+ {
+ int sockError;
+ SocketGetError(client->stream.sock, &sockError);
+ LOG_DEBUG("sockError: %d", sockError);
+ if (sockError == 0)
+ {
+ LOG_DEBUG("connected!");
+ client->state = MqttClientStateConnected;
+ return 0;
+ }
+ }
+
if (MqttClientSendPacket(client) == -1)
{
LOG_ERROR("MqttClientSendPacket failed");
@@ -689,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName,
{
assert(client != NULL);
- if (MqttClientIsConnected(client))
+ if (client->state == MqttClientStateConnecting)
{
LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect");
return -1;
@@ -767,23 +816,29 @@ static int MqttClientSendPacket(MqttClient *client)
if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
return -1;
}
packet->state = MqttPacketStateWriteRemainingLength;
+ packet->remainingLength = blength(packet->payload);
break;
}
case MqttPacketStateWriteRemainingLength:
{
- if (StreamWriteRemainingLength(blength(packet->payload),
+ if (StreamWriteRemainingLength(&packet->remainingLength,
&client->stream.base) == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
return -1;
}
packet->state = MqttPacketStateWritePayload;
+ packet->remainingLength = blength(packet->payload);
break;
}
@@ -792,15 +847,36 @@ static int MqttClientSendPacket(MqttClient *client)
{
if (packet->payload)
{
- if (StreamWrite(bdata(packet->payload),
- blength(packet->payload),
- &client->stream.base) == -1)
+ int64_t offset = blength(packet->payload) - packet->remainingLength;
+ int64_t nwritten = 0;
+ int towrite = 16*1024;
+
+ if (packet->remainingLength < 16*1024)
+ towrite = packet->remainingLength;
+
+ nwritten = StreamWrite(bdataofs(packet->payload, offset),
+ towrite,
+ &client->stream.base);
+
+ if (nwritten == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ {
+ return 0;
+ }
return -1;
}
+
+ packet->remainingLength -= nwritten;
+
+ LOG_DEBUG("nwritten:%d", (int) nwritten);
}
- packet->state = MqttPacketStateWriteComplete;
+ if (packet->remainingLength == 0)
+ {
+ LOG_DEBUG("packet payload sent");
+ packet->state = MqttPacketStateWriteComplete;
+ }
break;
}
@@ -812,6 +888,7 @@ static int MqttClientSendPacket(MqttClient *client)
if (packet->type == MqttPacketTypeDisconnect)
{
client->stopped = 1;
+ client->state = MqttClientStateDisconnected;
}
LOG_DEBUG("sent %s", MqttPacketName(packet->type));
@@ -1353,11 +1430,12 @@ static int MqttClientRecvPacket(MqttClient *client)
case MqttPacketStateReadType:
{
unsigned char typeAndFlags;
- int rc;
- if ((rc = StreamReadByte(&typeAndFlags, &client->stream.base)) != 1)
+ if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1)
{
- LOG_ERROR("failed reading packet type: %d", rc);
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
+ LOG_ERROR("failed reading packet type");
return -1;
}
@@ -1372,6 +1450,9 @@ static int MqttClientRecvPacket(MqttClient *client)
}
client->inPacket.state = MqttPacketStateReadRemainingLength;
+ client->inPacket.remainingLength = 0;
+ client->inPacket.remainingLengthMul = 1;
+ client->inPacket.payload = NULL;
break;
}
@@ -1379,12 +1460,20 @@ static int MqttClientRecvPacket(MqttClient *client)
case MqttPacketStateReadRemainingLength:
{
if (StreamReadRemainingLength(&client->inPacket.remainingLength,
+ &client->inPacket.remainingLengthMul,
&client->stream.base) == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
LOG_ERROR("failed to read remaining length");
return -1;
}
+
+ LOG_DEBUG("remainingLength:%lu",
+ client->inPacket.remainingLength);
+
client->inPacket.state = MqttPacketStateReadPayload;
+
break;
}
@@ -1392,21 +1481,57 @@ static int MqttClientRecvPacket(MqttClient *client)
{
if (client->inPacket.remainingLength > 0)
{
- client->inPacket.payload = bfromcstr("");
- ballocmin(client->inPacket.payload,
- client->inPacket.remainingLength+1);
- if (StreamRead(bdata(client->inPacket.payload),
- client->inPacket.remainingLength,
- &client->stream.base) == -1)
+ int64_t nread, offset, toread;
+
+ if (client->inPacket.payload == NULL)
{
+ unsigned char *data;
+ client->inPacket.payload = bfromcstr("");
+ ballocmin(client->inPacket.payload,
+ client->inPacket.remainingLength+1);
+ data = client->inPacket.payload->data;
+ data[client->inPacket.remainingLength] = '\0';
+ }
+
+ offset = blength(client->inPacket.payload);
+
+ toread = 16*1024;
+
+ if (client->inPacket.remainingLength < (size_t) toread)
+ toread = client->inPacket.remainingLength;
+
+ nread = StreamRead(bdataofs(client->inPacket.payload,
+ offset),
+ toread,
+ &client->stream.base);
+
+ if (nread == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
LOG_ERROR("failed reading packet payload");
bdestroy(client->inPacket.payload);
client->inPacket.payload = NULL;
return -1;
}
- client->inPacket.payload->slen = client->inPacket.remainingLength;
+ else if (nread == 0)
+ {
+ LOG_ERROR("socket disconnected");
+ bdestroy(client->inPacket.payload);
+ client->inPacket.payload = NULL;
+ return -1;
+ }
+
+ client->inPacket.remainingLength -= nread;
+ client->inPacket.payload->slen += nread;
+
+ LOG_DEBUG("nread:%d", (int) nread);
+ }
+
+ if (client->inPacket.remainingLength == 0)
+ {
+ client->inPacket.state = MqttPacketStateReadComplete;
}
- client->inPacket.state = MqttPacketStateReadComplete;
break;
}
diff --git a/src/packet.h b/src/packet.h
index 36dc81f..a5e2ce7 100644
--- a/src/packet.h
+++ b/src/packet.h
@@ -53,6 +53,8 @@ struct MqttPacket
int state;
uint16_t id;
size_t remainingLength;
+ size_t remainingLengthMul;
+ /* TODO: maybe switch to have a StringStream here? */
bstring payload;
struct MqttMessage *message;
SIMPLEQ_ENTRY(MqttPacket) sendQueue;
diff --git a/src/socket.c b/src/socket.c
index 64a7c01..b70f4fb 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -6,18 +6,6 @@
#include <assert.h>
#if defined(_WIN32)
-#include "win32.h"
-#else
-#include <sys/types.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <sys/select.h>
-#include <netdb.h>
-#include <unistd.h>
-#include <arpa/inet.h>
-#endif
-
-#if defined(_WIN32)
static int InitializeWsa()
{
WSADATA wsa;
@@ -33,9 +21,9 @@ static int InitializeWsa()
#define close closesocket
#endif
-int SocketConnect(const char *host, short port)
+int SocketConnect(const char *host, short port, int nonblocking)
{
- struct addrinfo hints, *servinfo, *p = NULL;
+ struct addrinfo hints, *servinfo = NULL, *p = NULL;
int rv;
char portstr[6];
int sock;
@@ -66,8 +54,16 @@ int SocketConnect(const char *host, short port)
continue;
}
+ if (nonblocking)
+ {
+ SocketSetNonblocking(sock, 1);
+ }
+
if (connect(sock, p->ai_addr, p->ai_addrlen) == -1)
{
+ int err = SocketErrno;
+ if (err == SOCKET_EINPROGRESS)
+ break;
close(sock);
continue;
}
@@ -75,10 +71,13 @@ int SocketConnect(const char *host, short port)
break;
}
- freeaddrinfo(servinfo);
-
cleanup:
+ if (servinfo)
+ {
+ freeaddrinfo(servinfo);
+ }
+
if (p == NULL)
{
#if defined(_WIN32)
@@ -178,3 +177,45 @@ int SocketSelect(int sock, int *events, int timeout)
return rv;
}
+
+void SocketSetNonblocking(int sock, int nb)
+{
+#if defined(_WIN32)
+ unsigned int yes = nb;
+ ioctlsocket(s, FIONBIO, &yes);
+#else
+ int flags = fcntl(sock, F_GETFL, 0);
+ if (nb)
+ flags |= O_NONBLOCK;
+ else
+ flags &= ~O_NONBLOCK;
+ fcntl(sock, F_SETFL, flags);
+#endif
+}
+
+int SocketGetOpt(int sock, int level, int name, void *val, int *len)
+{
+#if defined(_WIN32)
+ return getsockopt(sock, level, name, (char *) val, len);
+#else
+ socklen_t _len = *len;
+ int rc = getsockopt(sock, level, name, val, &_len);
+ *len = _len;
+ return rc;
+#endif
+}
+
+int SocketGetError(int sock, int *error)
+{
+ int len = sizeof(*error);
+ return SocketGetOpt(sock, SOL_SOCKET, SO_ERROR, error, &len);
+}
+
+int SocketWouldBlock(int error)
+{
+#if defined(_WIN32)
+ return error == WSAEWOULDBLOCK;
+#else
+ return error == EWOULDBLOCK || error == EAGAIN;
+#endif
+}
diff --git a/src/socket.h b/src/socket.h
index e7b1a80..abc67af 100644
--- a/src/socket.h
+++ b/src/socket.h
@@ -6,7 +6,25 @@
#include <stdlib.h>
#include <stdint.h>
-int SocketConnect(const char *host, short port);
+#if defined(_WIN32)
+#include "win32.h"
+#define SocketErrno (WSAGetLastError())
+#define SOCKET_EINPROGRESS (WSAEWOULDBLOCK)
+#else
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/select.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <errno.h>
+#define SocketErrno (errno)
+#define SOCKET_EINPROGRESS (EINPROGRESS)
+#endif
+
+int SocketConnect(const char *host, short port, int nonblocking);
int SocketDisconnect(int sock);
@@ -24,4 +42,10 @@ int64_t SocketRecv(int sock, void *buf, size_t len, int flags);
int64_t SocketSend(int sock, const void *buf, size_t len, int flags);
+void SocketSetNonblocking(int sock, int nb);
+
+int SocketGetError(int sock, int *error);
+
+int SocketWouldBlock(int error);
+
#endif
diff --git a/src/stream_mqtt.c b/src/stream_mqtt.c
index 3864ef3..f2bd9cd 100644
--- a/src/stream_mqtt.c
+++ b/src/stream_mqtt.c
@@ -42,37 +42,39 @@ int64_t StreamWriteMqttString(const_bstring buf, Stream *stream)
return 2 + blength(buf);
}
-int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream)
+int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul,
+ Stream *stream)
{
- size_t multiplier = 1;
unsigned char encodedByte;
- *remainingLength = 0;
do
{
if (StreamRead(&encodedByte, 1, stream) != 1)
return -1;
- *remainingLength += (encodedByte & 127) * multiplier;
- if (multiplier > 128*128*128)
+ *remainingLength += (encodedByte & 127) * (*mul);
+ if ((*mul) > 128*128*128)
return -1;
- multiplier *= 128;
+ (*mul) *= 128;
}
while ((encodedByte & 128) != 0);
+ *mul = 0;
return 0;
}
-int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream)
+int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream)
{
- size_t nbytes = 0;
do
{
- unsigned char encodedByte = remainingLength % 128;
- remainingLength /= 128;
- if (remainingLength > 0)
+ size_t tmp = *remainingLength;
+ unsigned char encodedByte = tmp % 128;
+ tmp /= 128;
+ if (tmp > 0)
encodedByte |= 128;
if (StreamWrite(&encodedByte, 1, stream) != 1)
+ {
return -1;
- ++nbytes;
+ }
+ *remainingLength = tmp;
}
- while (remainingLength > 0);
- return nbytes;
+ while (*remainingLength > 0);
+ return 0;
}
diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h
index a128524..8c8ccb5 100644
--- a/src/stream_mqtt.h
+++ b/src/stream_mqtt.h
@@ -9,7 +9,8 @@
int64_t StreamReadMqttString(bstring *buf, Stream *stream);
int64_t StreamWriteMqttString(const_bstring buf, Stream *stream);
-int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream);
-int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream);
+int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul,
+ Stream *stream);
+int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream);
#endif