From 5b25e192705d4040b51665e49627f8ea037186ee Mon Sep 17 00:00:00 2001 From: Oskari Timperi Date: Sat, 24 Mar 2018 18:29:11 +0200 Subject: Initial support for embedded messages --- src/protobuf/gen.nim | 130 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 115 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/protobuf/gen.nim b/src/protobuf/gen.nim index 881b275..76d36fe 100644 --- a/src/protobuf/gen.nim +++ b/src/protobuf/gen.nim @@ -72,9 +72,16 @@ proc getFieldType(field: NimNode): FieldType = let node = findColonExpr(field, "ftype") result = FieldType(node[1].intVal) +proc getFieldTypeName(field: NimNode): string = + let node = findColonExpr(field, "typeName") + result = $node[1] + proc getFullFieldType(field: NimNode): NimNode = let ftype = getFieldType(field) - result = toNimNode(ftype) + if ftype == FieldType.Message: + result = ident(getFieldTypeName(field)) + else: + result = toNimNode(ftype) if isRepeated(field): result = nnkBracketExpr.newTree(ident("seq"), result) @@ -102,7 +109,7 @@ proc defaultValue(field: NimNode): NimNode = of FieldType.Bool: result = newLit(false) of FieldType.String: result = newLit("") of FieldType.Group: result = newLit("NOTIMPLEMENTED") - of FieldType.Message: result = newLit("TODO") + of FieldType.Message: result = newCall(ident("new" & getFieldTypeName(field))) of FieldType.Bytes: result = newCall(ident("bytes"), newLit("")) of FieldType.UInt32: result = newLit(0'u32) of FieldType.Enum: result = newLit("TODO") @@ -242,7 +249,11 @@ proc generateAddToFieldProc(desc, field: NimNode): NimNode = add(body, newCall("incl", newDotExpr(ident("message"), ident("hasField")), newLit(getFieldNumber(field)))) - let ftype = toNimNode(getFieldType(field)) + let ftype = + if getFieldType(field) == FieldType.Message: + ident(getFieldTypeName(field)) + else: + toNimNode(getFieldType(field)) result = newProc(postfix(ident("add" & capitalizeAscii(fieldName)), "*"), @[newEmptyNode(), newIdentDefs(ident("message"), @@ -258,7 +269,11 @@ proc genWriteField(field: NimNode): NimNode = let number = getFieldNumber(field) - writer = ident("write" & $getFieldType(field)) + writer = + if getFieldType(field) == FieldType.Message: + ident("write" & getFieldTypeName(field)) + else: + ident("write" & $getFieldType(field)) fname = newDotExpr(ident("message"), ident(getFieldName(field))) wiretype = ident(wiretype(field)) @@ -266,6 +281,9 @@ proc genWriteField(field: NimNode): NimNode = result.add quote do: writeTag(stream, `number`, `wiretype`) `writer`(stream, `fname`) + if getFieldType(field) == FieldType.Message: + insert(result[^1], 1, newCall(ident("writeVarint"), ident("stream"), + newCall(ident("sizeOf" & getFieldTypeName(field)), fname))) else: if isPacked(field): result.add quote do: @@ -274,10 +292,14 @@ proc genWriteField(field: NimNode): NimNode = for value in `fname`: `writer`(stream, value) else: + let valueId = ident("value") result.add quote do: - for value in `fname`: + for `valueId` in `fname`: writeTag(stream, `number`, `wiretype`) - `writer`(stream, value) + `writer`(stream, `valueId`) + if getFieldType(field) == FieldType.Message: + insert(result[^1][^1], 1, newCall(ident("writeVarint"), ident("stream"), + newCall(ident("sizeOf" & getFieldTypeName(field)), valueId))) proc generateWriteMessageProc(desc: NimNode): NimNode = let body = newStmtList() @@ -305,7 +327,7 @@ proc generateReadMessageProc(desc: NimNode): NimNode = let resultId = ident("result") let body = newStmtList( - newCall(ident("new"), resultId) + newAssignment(resultId, newCall(ident("new" & getMessageName(desc)))) ) let tagid = ident("tag") @@ -325,7 +347,11 @@ proc generateReadMessageProc(desc: NimNode): NimNode = let number = getFieldNumber(field) if isRepeated(field): let adder = ident("add" & capitalizeAscii(getFieldName(field))) - let reader = ident("read" & $getFieldType(field)) + let reader = + if getFieldType(field) == FieldType.Message: + ident("read" & getFieldTypeName(field)) + else: + ident("read" & $getFieldType(field)) if isNumeric(getFieldType(field)): add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: if `wiretypeId` == WireType.LengthDelimited: @@ -343,15 +369,36 @@ proc generateReadMessageProc(desc: NimNode): NimNode = `adder`(`resultId`, `reader`(stream)) )) else: - add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: - `adder`(`resultId`, `reader`(stream)) - )) + if getFieldType(field) == FieldType.Message: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `adder`(`resultId`, `reader`(stream2)) + )) + else: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + `adder`(`resultId`, `reader`(stream)) + )) else: let setter = ident("set" & capitalizeAscii(getFieldName(field))) - let reader = ident("read" & $getFieldType(field)) - add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: - `setter`(`resultId`, `reader`(stream)) - )) + let reader = + if getFieldType(field) == FieldType.Message: + ident("read" & getFieldTypeName(field)) + else: + ident("read" & $getFieldType(field)) + if getFieldType(field) == FieldType.Message: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `setter`(`resultId`, `reader`(stream2)) + )) + else: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + `setter`(`resultId`, `reader`(stream)) + )) + # TODO: generate code to skip unknown fields add(caseNode, nnkElse.newTree(quote do: @@ -362,6 +409,58 @@ proc generateReadMessageProc(desc: NimNode): NimNode = @[ident(name), newIdentDefs(ident("stream"), ident("ProtobufStream"))], body) +proc generateSizeOfMessageProc(desc: NimNode): NimNode = + let + name = getMessageName(desc) + body = newStmtList() + messageId = ident("message") + resultId = ident("result") + + for field in fields(desc): + let + hasproc = ident("has" & capitalizeAscii(getFieldName(field))) + sizeofproc = + if getFieldType(field) == FieldType.Message: + ident("sizeOf" & getFieldTypeName(field)) + else: + ident("sizeOf" & $getFieldType(field)) + fname = newDotExpr(messageId, ident(getFieldName(field))) + number = getFieldNumber(field) + wiretype = ident(wiretype(field)) + + # TODO: packed + if isRepeated(field): + body.add quote do: + if `hasproc`(`messageId`): + for value in `fname`: + let + sizeOfField = `sizeofproc`(value) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + + sizeOfField + + sizeOfUint64(sizeOfField) + + tagSize + else: + if getFieldType(field) == FieldType.Message: + body.add quote do: + if `hasproc`(`messageId`): + let + sizeOfField = `sizeofproc`(`fname`) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + sizeOfField + tagSize + + sizeOfUint64(sizeOfField) + else: + body.add quote do: + if `hasproc`(`messageId`): + let + sizeOfField = `sizeofproc`(`fname`) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + sizeOfField + tagSize + + result = newProc(postfix(ident("sizeOf" & name), "*"), + @[ident("uint64"), newIdentDefs(messageId, ident(name))], + body) + macro generateMessageProcs*(x: typed): typed = let desc = getImpl(symbol(x)) @@ -380,6 +479,7 @@ macro generateMessageProcs*(x: typed): typed = add(result, generateWriteMessageProc(desc)) add(result, generateReadMessageProc(desc)) + add(result, generateSizeOfMessageProc(desc)) when defined(debug): hint(repr(result)) -- cgit v1.2.3