From 8198f6d6beb3c8af3768236070089112c094b92e Mon Sep 17 00:00:00 2001 From: Oskari Timperi Date: Sun, 19 Feb 2017 16:03:56 +0200 Subject: Add MqttClientSubscribeMany() and make necessary API changes --- src/client.c | 52 +++++++++++++++++++++------ src/deserialize.c | 17 +++++---- src/mqtt.h | 4 +++ src/packet.c | 2 +- src/packet.h | 6 ++-- src/serialize.c | 26 +++++++++++--- test/interop/overlapping_subscriptions_test.c | 5 +-- test/interop/subscribe_failure_test.c | 2 +- test/interop/testclient.c | 41 +++++++++++++++++++-- test/interop/testclient.h | 6 +++- tools/sub.c | 3 +- 11 files changed, 131 insertions(+), 33 deletions(-) diff --git a/src/client.c b/src/client.c index c4bd499..4c4ed7d 100644 --- a/src/client.c +++ b/src/client.c @@ -414,12 +414,21 @@ int MqttClientRun(MqttClient *client) int MqttClientSubscribe(MqttClient *client, const char *topicFilter, int qos) +{ + return MqttClientSubscribeMany(client, &topicFilter, &qos, 1); +} + +int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, + int *qos, size_t count) { MqttPacketSubscribe *packet = NULL; + size_t i; assert(client != NULL); - assert(topicFilter != NULL); - assert(qos >= 0 && qos <= 2); + assert(topicFilters != NULL); + assert(*topicFilters != NULL); + assert(qos != NULL); + assert(count > 0); packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); @@ -427,8 +436,18 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, if (!packet) return -1; - packet->topicFilter = bfromcstr(topicFilter); - packet->qos = qos; + packet->topicFilters = bstrListCreate(); + bstrListAllocMin(packet->topicFilters, count); + + packet->qos = (int *) malloc(sizeof(int) * count); + + for (i = 0; i < count; ++i) + { + packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]); + ++packet->topicFilters->qty; + } + + memcpy(packet->qos, qos, sizeof(int) * count); MqttClientQueuePacket(client, (MqttPacket *) packet); @@ -662,16 +681,27 @@ static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) } else { - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - if (client->onSubscribe) { - LOG_DEBUG("calling onSubscribe id:%d rc:%d", MqttPacketId(packet), - packet->returnCode); - client->onSubscribe(client, MqttPacketId(packet), - packet->returnCode); + MqttPacketSubscribe *sub2; + int i; + + sub2 = (MqttPacketSubscribe *) sub; + + for (i = 0; i < sub2->topicFilters->qty; ++i) + { + const char *filter = bdata(sub2->topicFilters->entry[i]); + int rc = packet->returnCode[i]; + + LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d", + MqttPacketId(packet), filter, rc); + + client->onSubscribe(client, MqttPacketId(packet), filter, rc); + } } + + TAILQ_REMOVE(&client->outMessages, sub, messages); + MqttPacketFree(sub); } } diff --git a/src/deserialize.c b/src/deserialize.c index aaff490..96d7789 100644 --- a/src/deserialize.c +++ b/src/deserialize.c @@ -46,19 +46,24 @@ static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stre static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream) { size_t remainingLength = 0; + size_t i; if (StreamReadRemainingLength(&remainingLength, stream) == -1) return -1; - /* 2 bytes for packet id and 1 byte for single return code */ - if (remainingLength != 3) - return -1; - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) return -1; - if (StreamRead(&((*packet)->returnCode), 1, stream) == -1) - return -1; + remainingLength -= 2; + + (*packet)->returnCode = (unsigned char *) malloc( + sizeof(*(*packet)->returnCode) * remainingLength); + + for (i = 0; i < remainingLength; ++i) + { + if (StreamRead(&((*packet)->returnCode[i]), 1, stream) == -1) + return -1; + } return 0; } diff --git a/src/mqtt.h b/src/mqtt.h index f07ff3c..ad84aaf 100644 --- a/src/mqtt.h +++ b/src/mqtt.h @@ -33,6 +33,7 @@ typedef void (*MqttClientOnConnectCallback)(MqttClient *client, typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client, int id, + const char *topicFilter, MqttSubscriptionStatus status); typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id); @@ -82,6 +83,9 @@ int MqttClientRun(MqttClient *client); int MqttClientSubscribe(MqttClient *client, const char *topicFilter, int qos); +int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, + int *qos, size_t count); + int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter); int MqttClientPublish(MqttClient *client, int qos, int retain, diff --git a/src/packet.c b/src/packet.c index 5de7d97..47aa689 100644 --- a/src/packet.c +++ b/src/packet.c @@ -96,7 +96,7 @@ void MqttPacketFree(MqttPacket *packet) else if (MqttPacketType(packet) == MqttPacketTypeSubscribe) { MqttPacketSubscribe *p = (MqttPacketSubscribe *) packet; - bdestroy(p->topicFilter); + bstrListDestroy(p->topicFilters); } else if (MqttPacketType(packet) == MqttPacketTypeUnsubscribe) { diff --git a/src/packet.h b/src/packet.h index 4fe7b74..7ab4f73 100644 --- a/src/packet.h +++ b/src/packet.h @@ -92,8 +92,8 @@ typedef struct MqttPacketSubscribe MqttPacketSubscribe; struct MqttPacketSubscribe { MqttPacket base; - bstring topicFilter; - char qos; + struct bstrList *topicFilters; + int *qos; }; typedef struct MqttPacketSubAck MqttPacketSubAck; @@ -101,7 +101,7 @@ typedef struct MqttPacketSubAck MqttPacketSubAck; struct MqttPacketSubAck { MqttPacket base; - unsigned char returnCode; + unsigned char *returnCode; }; typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe; diff --git a/src/serialize.c b/src/serialize.c index b6c8cbc..3378b80 100644 --- a/src/serialize.c +++ b/src/serialize.c @@ -42,7 +42,16 @@ static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packe static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet) { - return 2 + MqttStringLengthSerialized(packet->topicFilter) + 1; + size_t remaining = 2; + int i; + + for (i = 0; i < packet->topicFilters->qty; ++i) + { + remaining += MqttStringLengthSerialized(packet->topicFilters->entry[i]); + remaining += 1; + } + + return remaining; } static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet) @@ -197,14 +206,21 @@ static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *s static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream) { + int i; + if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) return -1; - if (StreamWriteMqttString(packet->topicFilter, stream) == -1) - return -1; + for (i = 0; i < packet->topicFilters->qty; ++i) + { + unsigned char qos = (unsigned char) packet->qos[i]; - if (StreamWrite(&packet->qos, 1, stream) == -1) - return -1; + if (StreamWriteMqttString(packet->topicFilters->entry[i], stream) == -1) + return -1; + + if (StreamWrite(&qos, 1, stream) == -1) + return -1; + } return 0; } diff --git a/test/interop/overlapping_subscriptions_test.c b/test/interop/overlapping_subscriptions_test.c index c6e5da0..ec6f061 100644 --- a/test/interop/overlapping_subscriptions_test.c +++ b/test/interop/overlapping_subscriptions_test.c @@ -7,11 +7,12 @@ TEST overlapping_subscriptions_test() { TestClient *client; int count; + const char *mywildtopics[] = { wildtopics[6], wildtopics[0] }; + int qos[] = { 2, 1 }; client = TestClientNew("clienta"); ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); - ASSERT(TestClientSubscribe(client, wildtopics[6], 2)); - ASSERT(TestClientSubscribe(client, wildtopics[0], 1)); + ASSERT(TestClientSubscribeMany(client, mywildtopics, qos, 2)); ASSERT(TestClientPublish(client, 2, 0, topics[3], "overlapping topic filters")); ASSERT(TestClientWait(client, 1000)); diff --git a/test/interop/subscribe_failure_test.c b/test/interop/subscribe_failure_test.c index 84d3ebc..07a0d1b 100644 --- a/test/interop/subscribe_failure_test.c +++ b/test/interop/subscribe_failure_test.c @@ -10,7 +10,7 @@ TEST subscribe_failure_test() client = TestClientNew("clienta"); ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); ASSERT_FALSE(TestClientSubscribe(client, nosubscribe_topics[0], 2)); - ASSERT_EQ(MqttSubscriptionFailure, client->subStatus); + ASSERT_EQ(MqttSubscriptionFailure, client->subStatus[0]); TestClientDisconnect(client); TestClientFree(client); diff --git a/test/interop/testclient.c b/test/interop/testclient.c index c27945d..8d616f6 100644 --- a/test/interop/testclient.c +++ b/test/interop/testclient.c @@ -14,11 +14,12 @@ static void TestClientOnConnect(MqttClient *client, } static void TestClientOnSubscribe(MqttClient *client, int id, + const char *filter, MqttSubscriptionStatus status) { TestClient *testClient = (TestClient *) MqttClientGetUserData(client); testClient->subId = id; - testClient->subStatus = status; + testClient->subStatus[testClient->subCount++] = status; } static void TestClientOnPublish(MqttClient *client, int id) @@ -132,6 +133,7 @@ int TestClientSubscribe(TestClient *client, const char *topicFilter, int qos) int id = MqttClientSubscribe(client->client, topicFilter, qos); client->subId = -1; + client->subCount = 0; while (MqttClientRunOnce(client->client, -1) != -1) { @@ -147,7 +149,42 @@ int TestClientSubscribe(TestClient *client, const char *topicFilter, int qos) } } - return client->subStatus != MqttSubscriptionFailure; + return client->subStatus[0] != MqttSubscriptionFailure; +} + +int TestClientSubscribeMany(TestClient *client, const char **topicFilter, + int *qos, size_t count) +{ + int id = MqttClientSubscribeMany(client->client, topicFilter, qos, count); + int fail = 0, i; + + client->subId = -1; + client->subCount = 0; + + while (MqttClientRunOnce(client->client, -1) != -1) + { + if (client->subId != -1) + { + if (client->subId != id) + { + printf( + "WARNING: subscription id mismatch: expected %d, got %d\n", + id, client->subId); + } + break; + } + } + + for (i = 0; i < client->subCount; ++i) + { + if (client->subStatus[i] == MqttSubscriptionFailure) + { + fail = 1; + break; + } + } + + return !fail; } int TestClientPublish(TestClient *client, int qos, int retain, diff --git a/test/interop/testclient.h b/test/interop/testclient.h index 70805c6..2aa229b 100644 --- a/test/interop/testclient.h +++ b/test/interop/testclient.h @@ -29,7 +29,8 @@ struct TestClient /* OnSubscribe */ int subId; - MqttSubscriptionStatus subStatus; + MqttSubscriptionStatus subStatus[16]; + int subCount; /* OnPublish */ int pubId; @@ -54,6 +55,9 @@ void TestClientDisconnect(TestClient *client); int TestClientSubscribe(TestClient *client, const char *topicFilter, int qos); +int TestClientSubscribeMany(TestClient *client, const char **topicFilter, + int *qos, size_t count); + int TestClientPublish(TestClient *client, int qos, int retain, const char *topic, const char *message); diff --git a/tools/sub.c b/tools/sub.c index b556a27..ebf3372 100644 --- a/tools/sub.c +++ b/tools/sub.c @@ -21,7 +21,8 @@ void onConnect(MqttClient *client, MqttConnectionStatus status, MqttClientSubscribe(client, options->topic, options->qos); } -void onSubscribe(MqttClient *client, int id, MqttSubscriptionStatus status) +void onSubscribe(MqttClient *client, int id, const char *filter, + MqttSubscriptionStatus status) { (void) client; printf("onSubscribe id=%d status=%d\n", id, status); -- cgit v1.2.3