aboutsummaryrefslogtreecommitdiff
path: root/src/client.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/client.c')
-rw-r--r--src/client.c183
1 files changed, 154 insertions, 29 deletions
diff --git a/src/client.c b/src/client.c
index 704a53e..e303fe9 100644
--- a/src/client.c
+++ b/src/client.c
@@ -26,6 +26,15 @@
#error define PRId64 for your platform
#endif
+typedef enum MqttClientState MqttClientState;
+
+enum MqttClientState
+{
+ MqttClientStateDisconnected,
+ MqttClientStateConnecting,
+ MqttClientStateConnected,
+};
+
struct MqttClient
{
SocketStream stream;
@@ -80,6 +89,7 @@ struct MqttClient
bstring password;
/* The packet we are receiving */
MqttPacket inPacket;
+ MqttClientState state;
};
static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet);
@@ -135,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId)
client->maxQueued = 0;
client->maxInflight = 20;
+ client->state = MqttClientStateDisconnected;
+
TAILQ_INIT(&client->outMessages);
TAILQ_INIT(&client->inMessages);
SIMPLEQ_INIT(&client->sendQueue);
@@ -249,6 +261,12 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
assert(client != NULL);
assert(host != NULL);
+ if (client->state != MqttClientStateDisconnected)
+ {
+ LOG_ERROR("client must be disconnected to connect");
+ return -1;
+ }
+
if (client->host)
bassigncstr(client->host, host);
else
@@ -270,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
LOG_DEBUG("connecting");
- if ((sock = SocketConnect(host, port)) == -1)
+ if ((sock = SocketConnect(host, port, 1)) == -1)
{
- LOG_ERROR("SocketConnect failed!");
- return -1;
+ if (SocketErrno != SOCKET_EINPROGRESS)
+ {
+ LOG_ERROR("SocketConnect failed!");
+ return -1;
+ }
}
if (SocketStreamOpen(&client->stream, sock) == -1)
@@ -313,6 +334,8 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
MqttClientQueuePacket(client, packet);
+ client->state = MqttClientStateConnecting;
+
return 0;
}
@@ -324,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client)
int MqttClientIsConnected(MqttClient *client)
{
- return client->stream.sock != -1;
+ return client->stream.sock != -1 &&
+ client->state == MqttClientStateConnected;
}
int MqttClientRunOnce(MqttClient *client, int timeout)
{
int rv;
- int events;
+ int events = 0;
assert(client != NULL);
@@ -340,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
return -1;
}
- events = EV_READ;
+ if (client->state == MqttClientStateConnected)
+ {
+ events = EV_READ;
- /* Handle outMessages and inMessages, moving queued messages to sendQueue
- if there are less than maxInflight number of messages in flight */
- MqttClientProcessMessageQueue(client);
+ /* Handle outMessages and inMessages, moving queued messages to sendQueue
+ if there are less than maxInflight number of messages in flight */
+ MqttClientProcessMessageQueue(client);
- if (SIMPLEQ_EMPTY(&client->sendQueue))
+ if (SIMPLEQ_EMPTY(&client->sendQueue))
+ {
+ LOG_DEBUG("nothing to write");
+ }
+ else
+ {
+ events |= EV_WRITE;
+ }
+ }
+ else if (client->state == MqttClientStateConnecting)
{
- LOG_DEBUG("nothing to write");
+ events = EV_WRITE;
}
else
{
- events |= EV_WRITE;
+ LOG_ERROR("not connected");
+ return -1;
}
LOG_DEBUG("selecting");
@@ -385,6 +421,19 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
{
LOG_DEBUG("socket writable");
+ if (client->state == MqttClientStateConnecting)
+ {
+ int sockError;
+ SocketGetError(client->stream.sock, &sockError);
+ LOG_DEBUG("sockError: %d", sockError);
+ if (sockError == 0)
+ {
+ LOG_DEBUG("connected!");
+ client->state = MqttClientStateConnected;
+ return 0;
+ }
+ }
+
if (MqttClientSendPacket(client) == -1)
{
LOG_ERROR("MqttClientSendPacket failed");
@@ -689,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName,
{
assert(client != NULL);
- if (MqttClientIsConnected(client))
+ if (client->state == MqttClientStateConnecting)
{
LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect");
return -1;
@@ -767,23 +816,29 @@ static int MqttClientSendPacket(MqttClient *client)
if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
return -1;
}
packet->state = MqttPacketStateWriteRemainingLength;
+ packet->remainingLength = blength(packet->payload);
break;
}
case MqttPacketStateWriteRemainingLength:
{
- if (StreamWriteRemainingLength(blength(packet->payload),
+ if (StreamWriteRemainingLength(&packet->remainingLength,
&client->stream.base) == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
return -1;
}
packet->state = MqttPacketStateWritePayload;
+ packet->remainingLength = blength(packet->payload);
break;
}
@@ -792,15 +847,36 @@ static int MqttClientSendPacket(MqttClient *client)
{
if (packet->payload)
{
- if (StreamWrite(bdata(packet->payload),
- blength(packet->payload),
- &client->stream.base) == -1)
+ int64_t offset = blength(packet->payload) - packet->remainingLength;
+ int64_t nwritten = 0;
+ int towrite = 16*1024;
+
+ if (packet->remainingLength < 16*1024)
+ towrite = packet->remainingLength;
+
+ nwritten = StreamWrite(bdataofs(packet->payload, offset),
+ towrite,
+ &client->stream.base);
+
+ if (nwritten == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ {
+ return 0;
+ }
return -1;
}
+
+ packet->remainingLength -= nwritten;
+
+ LOG_DEBUG("nwritten:%d", (int) nwritten);
}
- packet->state = MqttPacketStateWriteComplete;
+ if (packet->remainingLength == 0)
+ {
+ LOG_DEBUG("packet payload sent");
+ packet->state = MqttPacketStateWriteComplete;
+ }
break;
}
@@ -812,6 +888,7 @@ static int MqttClientSendPacket(MqttClient *client)
if (packet->type == MqttPacketTypeDisconnect)
{
client->stopped = 1;
+ client->state = MqttClientStateDisconnected;
}
LOG_DEBUG("sent %s", MqttPacketName(packet->type));
@@ -1353,11 +1430,12 @@ static int MqttClientRecvPacket(MqttClient *client)
case MqttPacketStateReadType:
{
unsigned char typeAndFlags;
- int rc;
- if ((rc = StreamReadByte(&typeAndFlags, &client->stream.base)) != 1)
+ if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1)
{
- LOG_ERROR("failed reading packet type: %d", rc);
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
+ LOG_ERROR("failed reading packet type");
return -1;
}
@@ -1372,6 +1450,9 @@ static int MqttClientRecvPacket(MqttClient *client)
}
client->inPacket.state = MqttPacketStateReadRemainingLength;
+ client->inPacket.remainingLength = 0;
+ client->inPacket.remainingLengthMul = 1;
+ client->inPacket.payload = NULL;
break;
}
@@ -1379,12 +1460,20 @@ static int MqttClientRecvPacket(MqttClient *client)
case MqttPacketStateReadRemainingLength:
{
if (StreamReadRemainingLength(&client->inPacket.remainingLength,
+ &client->inPacket.remainingLengthMul,
&client->stream.base) == -1)
{
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
LOG_ERROR("failed to read remaining length");
return -1;
}
+
+ LOG_DEBUG("remainingLength:%lu",
+ client->inPacket.remainingLength);
+
client->inPacket.state = MqttPacketStateReadPayload;
+
break;
}
@@ -1392,21 +1481,57 @@ static int MqttClientRecvPacket(MqttClient *client)
{
if (client->inPacket.remainingLength > 0)
{
- client->inPacket.payload = bfromcstr("");
- ballocmin(client->inPacket.payload,
- client->inPacket.remainingLength+1);
- if (StreamRead(bdata(client->inPacket.payload),
- client->inPacket.remainingLength,
- &client->stream.base) == -1)
+ int64_t nread, offset, toread;
+
+ if (client->inPacket.payload == NULL)
{
+ unsigned char *data;
+ client->inPacket.payload = bfromcstr("");
+ ballocmin(client->inPacket.payload,
+ client->inPacket.remainingLength+1);
+ data = client->inPacket.payload->data;
+ data[client->inPacket.remainingLength] = '\0';
+ }
+
+ offset = blength(client->inPacket.payload);
+
+ toread = 16*1024;
+
+ if (client->inPacket.remainingLength < (size_t) toread)
+ toread = client->inPacket.remainingLength;
+
+ nread = StreamRead(bdataofs(client->inPacket.payload,
+ offset),
+ toread,
+ &client->stream.base);
+
+ if (nread == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
LOG_ERROR("failed reading packet payload");
bdestroy(client->inPacket.payload);
client->inPacket.payload = NULL;
return -1;
}
- client->inPacket.payload->slen = client->inPacket.remainingLength;
+ else if (nread == 0)
+ {
+ LOG_ERROR("socket disconnected");
+ bdestroy(client->inPacket.payload);
+ client->inPacket.payload = NULL;
+ return -1;
+ }
+
+ client->inPacket.remainingLength -= nread;
+ client->inPacket.payload->slen += nread;
+
+ LOG_DEBUG("nread:%d", (int) nread);
+ }
+
+ if (client->inPacket.remainingLength == 0)
+ {
+ client->inPacket.state = MqttPacketStateReadComplete;
}
- client->inPacket.state = MqttPacketStateReadComplete;
break;
}