diff options
Diffstat (limited to 'src/client.c')
| -rw-r--r-- | src/client.c | 183 |
1 files changed, 154 insertions, 29 deletions
diff --git a/src/client.c b/src/client.c index 704a53e..e303fe9 100644 --- a/src/client.c +++ b/src/client.c @@ -26,6 +26,15 @@ #error define PRId64 for your platform #endif +typedef enum MqttClientState MqttClientState; + +enum MqttClientState +{ + MqttClientStateDisconnected, + MqttClientStateConnecting, + MqttClientStateConnected, +}; + struct MqttClient { SocketStream stream; @@ -80,6 +89,7 @@ struct MqttClient bstring password; /* The packet we are receiving */ MqttPacket inPacket; + MqttClientState state; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); @@ -135,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId) client->maxQueued = 0; client->maxInflight = 20; + client->state = MqttClientStateDisconnected; + TAILQ_INIT(&client->outMessages); TAILQ_INIT(&client->inMessages); SIMPLEQ_INIT(&client->sendQueue); @@ -249,6 +261,12 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, assert(client != NULL); assert(host != NULL); + if (client->state != MqttClientStateDisconnected) + { + LOG_ERROR("client must be disconnected to connect"); + return -1; + } + if (client->host) bassigncstr(client->host, host); else @@ -270,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, LOG_DEBUG("connecting"); - if ((sock = SocketConnect(host, port)) == -1) + if ((sock = SocketConnect(host, port, 1)) == -1) { - LOG_ERROR("SocketConnect failed!"); - return -1; + if (SocketErrno != SOCKET_EINPROGRESS) + { + LOG_ERROR("SocketConnect failed!"); + return -1; + } } if (SocketStreamOpen(&client->stream, sock) == -1) @@ -313,6 +334,8 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, MqttClientQueuePacket(client, packet); + client->state = MqttClientStateConnecting; + return 0; } @@ -324,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client) int MqttClientIsConnected(MqttClient *client) { - return client->stream.sock != -1; + return client->stream.sock != -1 && + client->state == MqttClientStateConnected; } int MqttClientRunOnce(MqttClient *client, int timeout) { int rv; - int events; + int events = 0; assert(client != NULL); @@ -340,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout) return -1; } - events = EV_READ; + if (client->state == MqttClientStateConnected) + { + events = EV_READ; - /* Handle outMessages and inMessages, moving queued messages to sendQueue - if there are less than maxInflight number of messages in flight */ - MqttClientProcessMessageQueue(client); + /* Handle outMessages and inMessages, moving queued messages to sendQueue + if there are less than maxInflight number of messages in flight */ + MqttClientProcessMessageQueue(client); - if (SIMPLEQ_EMPTY(&client->sendQueue)) + if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + LOG_DEBUG("nothing to write"); + } + else + { + events |= EV_WRITE; + } + } + else if (client->state == MqttClientStateConnecting) { - LOG_DEBUG("nothing to write"); + events = EV_WRITE; } else { - events |= EV_WRITE; + LOG_ERROR("not connected"); + return -1; } LOG_DEBUG("selecting"); @@ -385,6 +421,19 @@ int MqttClientRunOnce(MqttClient *client, int timeout) { LOG_DEBUG("socket writable"); + if (client->state == MqttClientStateConnecting) + { + int sockError; + SocketGetError(client->stream.sock, &sockError); + LOG_DEBUG("sockError: %d", sockError); + if (sockError == 0) + { + LOG_DEBUG("connected!"); + client->state = MqttClientStateConnected; + return 0; + } + } + if (MqttClientSendPacket(client) == -1) { LOG_ERROR("MqttClientSendPacket failed"); @@ -689,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName, { assert(client != NULL); - if (MqttClientIsConnected(client)) + if (client->state == MqttClientStateConnecting) { LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect"); return -1; @@ -767,23 +816,29 @@ static int MqttClientSendPacket(MqttClient *client) if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) { + if (SocketWouldBlock(SocketErrno)) + return 0; return -1; } packet->state = MqttPacketStateWriteRemainingLength; + packet->remainingLength = blength(packet->payload); break; } case MqttPacketStateWriteRemainingLength: { - if (StreamWriteRemainingLength(blength(packet->payload), + if (StreamWriteRemainingLength(&packet->remainingLength, &client->stream.base) == -1) { + if (SocketWouldBlock(SocketErrno)) + return 0; return -1; } packet->state = MqttPacketStateWritePayload; + packet->remainingLength = blength(packet->payload); break; } @@ -792,15 +847,36 @@ static int MqttClientSendPacket(MqttClient *client) { if (packet->payload) { - if (StreamWrite(bdata(packet->payload), - blength(packet->payload), - &client->stream.base) == -1) + int64_t offset = blength(packet->payload) - packet->remainingLength; + int64_t nwritten = 0; + int towrite = 16*1024; + + if (packet->remainingLength < 16*1024) + towrite = packet->remainingLength; + + nwritten = StreamWrite(bdataofs(packet->payload, offset), + towrite, + &client->stream.base); + + if (nwritten == -1) { + if (SocketWouldBlock(SocketErrno)) + { + return 0; + } return -1; } + + packet->remainingLength -= nwritten; + + LOG_DEBUG("nwritten:%d", (int) nwritten); } - packet->state = MqttPacketStateWriteComplete; + if (packet->remainingLength == 0) + { + LOG_DEBUG("packet payload sent"); + packet->state = MqttPacketStateWriteComplete; + } break; } @@ -812,6 +888,7 @@ static int MqttClientSendPacket(MqttClient *client) if (packet->type == MqttPacketTypeDisconnect) { client->stopped = 1; + client->state = MqttClientStateDisconnected; } LOG_DEBUG("sent %s", MqttPacketName(packet->type)); @@ -1353,11 +1430,12 @@ static int MqttClientRecvPacket(MqttClient *client) case MqttPacketStateReadType: { unsigned char typeAndFlags; - int rc; - if ((rc = StreamReadByte(&typeAndFlags, &client->stream.base)) != 1) + if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1) { - LOG_ERROR("failed reading packet type: %d", rc); + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet type"); return -1; } @@ -1372,6 +1450,9 @@ static int MqttClientRecvPacket(MqttClient *client) } client->inPacket.state = MqttPacketStateReadRemainingLength; + client->inPacket.remainingLength = 0; + client->inPacket.remainingLengthMul = 1; + client->inPacket.payload = NULL; break; } @@ -1379,12 +1460,20 @@ static int MqttClientRecvPacket(MqttClient *client) case MqttPacketStateReadRemainingLength: { if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->inPacket.remainingLengthMul, &client->stream.base) == -1) { + if (SocketWouldBlock(SocketErrno)) + return 0; LOG_ERROR("failed to read remaining length"); return -1; } + + LOG_DEBUG("remainingLength:%lu", + client->inPacket.remainingLength); + client->inPacket.state = MqttPacketStateReadPayload; + break; } @@ -1392,21 +1481,57 @@ static int MqttClientRecvPacket(MqttClient *client) { if (client->inPacket.remainingLength > 0) { - client->inPacket.payload = bfromcstr(""); - ballocmin(client->inPacket.payload, - client->inPacket.remainingLength+1); - if (StreamRead(bdata(client->inPacket.payload), - client->inPacket.remainingLength, - &client->stream.base) == -1) + int64_t nread, offset, toread; + + if (client->inPacket.payload == NULL) { + unsigned char *data; + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + data = client->inPacket.payload->data; + data[client->inPacket.remainingLength] = '\0'; + } + + offset = blength(client->inPacket.payload); + + toread = 16*1024; + + if (client->inPacket.remainingLength < (size_t) toread) + toread = client->inPacket.remainingLength; + + nread = StreamRead(bdataofs(client->inPacket.payload, + offset), + toread, + &client->stream.base); + + if (nread == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; LOG_ERROR("failed reading packet payload"); bdestroy(client->inPacket.payload); client->inPacket.payload = NULL; return -1; } - client->inPacket.payload->slen = client->inPacket.remainingLength; + else if (nread == 0) + { + LOG_ERROR("socket disconnected"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + + client->inPacket.remainingLength -= nread; + client->inPacket.payload->slen += nread; + + LOG_DEBUG("nread:%d", (int) nread); + } + + if (client->inPacket.remainingLength == 0) + { + client->inPacket.state = MqttPacketStateReadComplete; } - client->inPacket.state = MqttPacketStateReadComplete; break; } |
