aboutsummaryrefslogtreecommitdiff
path: root/src/deserialize.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/deserialize.c')
-rw-r--r--src/deserialize.c278
1 files changed, 278 insertions, 0 deletions
diff --git a/src/deserialize.c b/src/deserialize.c
new file mode 100644
index 0000000..19be4ce
--- /dev/null
+++ b/src/deserialize.c
@@ -0,0 +1,278 @@
+#include "deserialize.h"
+#include "packet.h"
+#include "stream_mqtt.h"
+#include "log.h"
+
+#include <stdlib.h>
+#include <assert.h>
+
+typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream);
+
+static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream)
+{
+ size_t remainingLength = 0;
+
+ if (StreamReadRemainingLength(&remainingLength, stream) == -1)
+ return -1;
+
+ if (remainingLength != 2)
+ return -1;
+
+ if (StreamReadUint16Be(&(*packet)->id, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream)
+{
+ size_t remainingLength = 0;
+
+ if (StreamReadRemainingLength(&remainingLength, stream) == -1)
+ return -1;
+
+ if (remainingLength != 2)
+ return -1;
+
+ if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1)
+ return -1;
+
+ if (StreamRead(&(*packet)->returnCode, 1, stream) != 1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream)
+{
+ size_t remainingLength = 0;
+
+ if (StreamReadRemainingLength(&remainingLength, stream) == -1)
+ return -1;
+
+ // 2 bytes for packet id and 1 byte for single return code
+ if (remainingLength != 3)
+ return -1;
+
+ if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1)
+ return -1;
+
+ if (StreamRead(&((*packet)->returnCode), 1, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream)
+{
+ size_t remainingLength = 0;
+
+ if (StreamReadRemainingLength(&remainingLength, stream) == -1)
+ return -1;
+
+ if (remainingLength != 2)
+ return -1;
+
+ if (StreamReadUint16Be(&(*packet)->id, stream) == -1)
+ return -1;
+
+ return 0;
+}
+
+static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream)
+{
+ size_t remainingLength = 0;
+ size_t payloadSize = 0;
+
+ if (StreamReadRemainingLength(&remainingLength, stream) == -1)
+ return -1;
+
+ if (StreamReadMqttStringBuf(&(*packet)->topicName, stream) == -1)
+ return -1;
+
+ LOG_DEBUG("remainingLength:%lu", remainingLength);
+
+ payloadSize = remainingLength - (*packet)->topicName.len - 2;
+
+ LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet),
+ payloadSize);
+
+ if (MqttPacketHasId((const MqttPacket *) *packet))
+ {
+ LOG_DEBUG("packet has id");
+ payloadSize -= 2;
+ if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1)
+ {
+ return -1;
+ }
+ }
+
+ LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize);
+
+ if (StringBufInit(&((*packet)->message), payloadSize) == -1)
+ return -1;
+
+ if (StreamRead((*packet)->message.data, payloadSize, stream) == -1)
+ return -1;
+
+ (*packet)->message.len = payloadSize;
+
+ return 0;
+}
+
+static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream)
+{
+ size_t remainingLength = 0;
+ char buffer[256];
+
+ (void) packet;
+
+ if (StreamReadRemainingLength(&remainingLength, stream) == -1)
+ return -1;
+
+ while (remainingLength > 0)
+ {
+ size_t l = sizeof(buffer);
+
+ if (remainingLength < l)
+ l = remainingLength;
+
+ if (StreamRead(buffer, l, stream) != (int64_t) l)
+ return -1;
+
+ remainingLength -= l;
+ }
+
+ return 0;
+}
+
+static int ValidateFlags(int type, int flags)
+{
+ int rv = 0;
+
+ switch (type)
+ {
+ case MqttPacketTypePublish:
+ {
+ int qos = (flags >> 1) & 2;
+ if (qos >= 0 && qos <= 2)
+ rv = 1;
+ break;
+ }
+
+ case MqttPacketTypePubRel:
+ case MqttPacketTypeSubscribe:
+ case MqttPacketTypeUnsubscribe:
+ if (flags == 2)
+ {
+ rv = 1;
+ }
+ break;
+
+ default:
+ if (flags == 0)
+ {
+ rv = 1;
+ }
+ break;
+ }
+
+ return rv;
+}
+
+int MqttPacketDeserialize(MqttPacket **packet, Stream *stream)
+{
+ MqttPacketDeserializeFunc deserializer = NULL;
+ char typeAndFlags;
+ int type;
+ int flags;
+ int rv;
+
+ if (StreamRead(&typeAndFlags, 1, stream) != 1)
+ return -1;
+
+ type = (typeAndFlags & 0xF0) >> 4;
+ flags = (typeAndFlags & 0x0F);
+
+ if (!ValidateFlags(type, flags))
+ {
+ return -1;
+ }
+
+ switch (type)
+ {
+ case MqttPacketTypeConnect:
+ break;
+
+ case MqttPacketTypeConnAck:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize;
+ break;
+
+ case MqttPacketTypePublish:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize;
+ break;
+
+ case MqttPacketTypePubAck:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
+ break;
+
+ case MqttPacketTypePubRec:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
+ break;
+
+ case MqttPacketTypePubRel:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
+ break;
+
+ case MqttPacketTypePubComp:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
+ break;
+
+ case MqttPacketTypeSubscribe:
+ break;
+
+ case MqttPacketTypeSubAck:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize;
+ break;
+
+ case MqttPacketTypeUnsubscribe:
+ break;
+
+ case MqttPacketTypeUnsubAck:
+ deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize;
+ break;
+
+ case MqttPacketTypePingReq:
+ break;
+
+ case MqttPacketTypePingResp:
+ break;
+
+ case MqttPacketTypeDisconnect:
+ break;
+
+ default:
+ return -1;
+ }
+
+ if (!deserializer)
+ {
+ deserializer = MqttPacketGenericDeserializer;
+ }
+
+ *packet = MqttPacketNew(type);
+
+ if (!*packet)
+ return -1;
+
+ if (type == MqttPacketTypePublish)
+ {
+ MqttPacketPublishDup(*packet) = (flags >> 3) & 1;
+ MqttPacketPublishQos(*packet) = (flags >> 1) & 3;
+ MqttPacketPublishRetain(*packet) = flags & 1;
+ }
+
+ rv = deserializer(packet, stream);
+
+ return rv;
+}