aboutsummaryrefslogtreecommitdiff
path: root/src/serialize.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/serialize.c')
-rw-r--r--src/serialize.c309
1 files changed, 309 insertions, 0 deletions
diff --git a/src/serialize.c b/src/serialize.c
new file mode 100644
index 0000000..d14cf03
--- /dev/null
+++ b/src/serialize.c
@@ -0,0 +1,309 @@
+#include "serialize.h"
+#include "stringbuf.h"
+#include "packet.h"
+#include "stream_mqtt.h"
+#include "log.h"
+
+#include <stdlib.h>
+#include <assert.h>
+
+typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet,
+ Stream *stream);
+
+static const StringBuf MqttProtocolId = StaticStringBuf("MQTT");
+static const char MqttProtocolLevel = 0x04;
+
+static inline size_t MqttStringLengthSerialized(const StringBuf *s)
+{
+ return 2 + s->len;
+}
+
+static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet)
+{
+ size_t remainingLength = 0;
+
+ remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2;
+
+ remainingLength += MqttStringLengthSerialized(&packet->clientId);
+
+ if (packet->connectFlags & 0x80)
+ remainingLength += MqttStringLengthSerialized(&packet->userName);
+
+ if (packet->connectFlags & 0x40)
+ remainingLength += MqttStringLengthSerialized(&packet->password);
+
+ if (packet->connectFlags & 0x04)
+ remainingLength += MqttStringLengthSerialized(&packet->willTopic) +
+ MqttStringLengthSerialized(&packet->willMessage);
+
+ return remainingLength;
+}
+
+static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet)
+{
+ return 2 + MqttStringLengthSerialized(&packet->topicFilter) + 1;
+}
+
+static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet)
+{
+ return 2 + MqttStringLengthSerialized(&packet->topicFilter);
+}
+
+static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet)
+{
+ size_t remainingLength = 0;
+
+ remainingLength += MqttStringLengthSerialized(&packet->topicName);
+
+ // Packet id
+ if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2)
+ {
+ remainingLength += 2;
+ }
+
+ remainingLength += packet->message.len;
+
+ return remainingLength;
+}
+
+static size_t MqttPacketGetRemainingLength(const MqttPacket *packet)
+{
+ switch (packet->type)
+ {
+ case MqttPacketTypeConnect:
+ return MqttPacketConnectGetRemainingLength(
+ (MqttPacketConnect *) packet);
+
+ case MqttPacketTypeSubscribe:
+ return MqttPacketSubscribeGetRemainingLength(
+ (MqttPacketSubscribe *) packet);
+
+ case MqttPacketTypePublish:
+ return MqttPacketPublishGetRemainingLength(
+ (MqttPacketPublish *) packet);
+
+ case MqttPacketTypePubAck:
+ case MqttPacketTypePubRec:
+ case MqttPacketTypePubRel:
+ case MqttPacketTypePubComp:
+ return 2;
+
+ case MqttPacketTypeUnsubscribe:
+ return MqttPacketUnsubscribeGetRemainingLength(
+ (MqttPacketUnsubscribe *) packet);
+
+ default:
+ return 0;
+ }
+}
+
+static int MqttPacketFlags(const MqttPacket *packet)
+{
+ switch (packet->type)
+ {
+ case MqttPacketTypePublish:
+ return ((MqttPacketPublishDup(packet) & 1) << 3) |
+ ((MqttPacketPublishQos(packet) & 3) << 1) |
+ (MqttPacketPublishRetain(packet) & 1);
+
+ case MqttPacketTypePubRel:
+ case MqttPacketTypeSubscribe:
+ case MqttPacketTypeUnsubscribe:
+ return 0x2;
+
+ default:
+ return 0;
+ }
+}
+
+static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream)
+{
+ unsigned char typeAndFlags;
+ size_t remainingLength;
+
+ typeAndFlags = ((packet->type & 0x0F) << 4) |
+ (MqttPacketFlags(packet) & 0x0F);
+ remainingLength = MqttPacketGetRemainingLength(packet);
+
+ LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type,
+ MqttPacketName(packet->type), MqttPacketFlags(packet));
+
+ if (StreamWrite(&typeAndFlags, 1, stream) != 1)
+ return -1;
+
+ if (StreamWriteRemainingLength(remainingLength, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream)
+{
+ assert(MqttPacketHasId((const MqttPacket *) packet));
+
+ if (MqttPacketBaseSerialize(packet, stream) == -1)
+ return -1;
+
+ if (StreamWriteUint16Be(packet->id, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream)
+{
+ if (MqttPacketBaseSerialize(&packet->base, stream) == -1)
+ return -1;
+
+ if (StreamWriteMqttStringBuf(&MqttProtocolId, stream) == -1)
+ return -1;
+
+ if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1)
+ return -1;
+
+ if (StreamWrite(&packet->connectFlags, 1, stream) != 1)
+ return -1;
+
+ if (StreamWriteUint16Be(packet->keepAlive, stream) == -1)
+ return -1;
+
+ if (StreamWriteMqttStringBuf(&packet->clientId, stream) == -1)
+ return -1;
+
+ if (packet->connectFlags & 0x04)
+ {
+ if (StreamWriteMqttStringBuf(&packet->willTopic, stream) == -1)
+ return -1;
+
+ if (StreamWriteMqttStringBuf(&packet->willMessage, stream) == -1)
+ return -1;
+ }
+
+ if (packet->connectFlags & 0x80)
+ {
+ if (StreamWriteMqttStringBuf(&packet->userName, stream) == -1)
+ return -1;
+ }
+
+ if (packet->connectFlags & 0x40)
+ {
+ if (StreamWriteMqttStringBuf(&packet->password, stream) == -1)
+ return -1;
+ }
+
+ return 0;
+}
+
+static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream)
+{
+ if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1)
+ return -1;
+
+ if (StreamWriteMqttStringBuf(&packet->topicFilter, stream) == -1)
+ return -1;
+
+ if (StreamWrite(&packet->qos, 1, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream)
+{
+ if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1)
+ return -1;
+
+ if (StreamWriteMqttStringBuf(&packet->topicFilter, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream)
+{
+ if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1)
+ return -1;
+
+ if (StreamWriteMqttStringBuf(&packet->topicName, stream) == -1)
+ return -1;
+
+ LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet));
+
+ if (MqttPacketPublishQos(packet) > 0)
+ {
+ if (StreamWriteUint16Be(packet->base.id, stream) == -1)
+ return -1;
+ }
+
+ if (StreamWrite(packet->message.data, packet->message.len, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+int MqttPacketSerialize(const MqttPacket *packet, Stream *stream)
+{
+ MqttPacketSerializeFunc f = NULL;
+
+ switch (packet->type)
+ {
+ case MqttPacketTypeConnect:
+ f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize;
+ break;
+
+ case MqttPacketTypeConnAck:
+ break;
+
+ case MqttPacketTypePublish:
+ f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize;
+ break;
+
+ case MqttPacketTypePubAck:
+ f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
+ break;
+
+ case MqttPacketTypePubRec:
+ f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
+ break;
+
+ case MqttPacketTypePubRel:
+ f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
+ break;
+
+ case MqttPacketTypePubComp:
+ f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
+ break;
+
+ case MqttPacketTypeSubscribe:
+ f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize;
+ break;
+
+ case MqttPacketTypeSubAck:
+ break;
+
+ case MqttPacketTypeUnsubscribe:
+ f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize;
+ break;
+
+ case MqttPacketTypeUnsubAck:
+ break;
+
+ case MqttPacketTypePingReq:
+ f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize;
+ break;
+
+ case MqttPacketTypePingResp:
+ break;
+
+ case MqttPacketTypeDisconnect:
+ f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize;
+ break;
+
+ default:
+ return -1;
+ }
+
+ assert(f != NULL && "no serializer");
+
+ return f(packet, stream);
+}