diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2018-03-24 18:29:11 +0200 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2018-03-24 18:29:11 +0200 |
| commit | 5b25e192705d4040b51665e49627f8ea037186ee (patch) | |
| tree | 1c4f0fed3098c28a2ea0200166bb35e09cb4438a | |
| parent | 0bc7b67059868af65d2158a8aeade5b6f777431b (diff) | |
| download | nimpb-5b25e192705d4040b51665e49627f8ea037186ee.tar.gz nimpb-5b25e192705d4040b51665e49627f8ea037186ee.zip | |
Initial support for embedded messages
| -rw-r--r-- | src/protobuf/gen.nim | 130 | ||||
| -rw-r--r-- | tests/test3.nim | 55 |
2 files changed, 170 insertions, 15 deletions
diff --git a/src/protobuf/gen.nim b/src/protobuf/gen.nim index 881b275..76d36fe 100644 --- a/src/protobuf/gen.nim +++ b/src/protobuf/gen.nim @@ -72,9 +72,16 @@ proc getFieldType(field: NimNode): FieldType = let node = findColonExpr(field, "ftype") result = FieldType(node[1].intVal) +proc getFieldTypeName(field: NimNode): string = + let node = findColonExpr(field, "typeName") + result = $node[1] + proc getFullFieldType(field: NimNode): NimNode = let ftype = getFieldType(field) - result = toNimNode(ftype) + if ftype == FieldType.Message: + result = ident(getFieldTypeName(field)) + else: + result = toNimNode(ftype) if isRepeated(field): result = nnkBracketExpr.newTree(ident("seq"), result) @@ -102,7 +109,7 @@ proc defaultValue(field: NimNode): NimNode = of FieldType.Bool: result = newLit(false) of FieldType.String: result = newLit("") of FieldType.Group: result = newLit("NOTIMPLEMENTED") - of FieldType.Message: result = newLit("TODO") + of FieldType.Message: result = newCall(ident("new" & getFieldTypeName(field))) of FieldType.Bytes: result = newCall(ident("bytes"), newLit("")) of FieldType.UInt32: result = newLit(0'u32) of FieldType.Enum: result = newLit("TODO") @@ -242,7 +249,11 @@ proc generateAddToFieldProc(desc, field: NimNode): NimNode = add(body, newCall("incl", newDotExpr(ident("message"), ident("hasField")), newLit(getFieldNumber(field)))) - let ftype = toNimNode(getFieldType(field)) + let ftype = + if getFieldType(field) == FieldType.Message: + ident(getFieldTypeName(field)) + else: + toNimNode(getFieldType(field)) result = newProc(postfix(ident("add" & capitalizeAscii(fieldName)), "*"), @[newEmptyNode(), newIdentDefs(ident("message"), @@ -258,7 +269,11 @@ proc genWriteField(field: NimNode): NimNode = let number = getFieldNumber(field) - writer = ident("write" & $getFieldType(field)) + writer = + if getFieldType(field) == FieldType.Message: + ident("write" & getFieldTypeName(field)) + else: + ident("write" & $getFieldType(field)) fname = newDotExpr(ident("message"), ident(getFieldName(field))) wiretype = ident(wiretype(field)) @@ -266,6 +281,9 @@ proc genWriteField(field: NimNode): NimNode = result.add quote do: writeTag(stream, `number`, `wiretype`) `writer`(stream, `fname`) + if getFieldType(field) == FieldType.Message: + insert(result[^1], 1, newCall(ident("writeVarint"), ident("stream"), + newCall(ident("sizeOf" & getFieldTypeName(field)), fname))) else: if isPacked(field): result.add quote do: @@ -274,10 +292,14 @@ proc genWriteField(field: NimNode): NimNode = for value in `fname`: `writer`(stream, value) else: + let valueId = ident("value") result.add quote do: - for value in `fname`: + for `valueId` in `fname`: writeTag(stream, `number`, `wiretype`) - `writer`(stream, value) + `writer`(stream, `valueId`) + if getFieldType(field) == FieldType.Message: + insert(result[^1][^1], 1, newCall(ident("writeVarint"), ident("stream"), + newCall(ident("sizeOf" & getFieldTypeName(field)), valueId))) proc generateWriteMessageProc(desc: NimNode): NimNode = let body = newStmtList() @@ -305,7 +327,7 @@ proc generateReadMessageProc(desc: NimNode): NimNode = let resultId = ident("result") let body = newStmtList( - newCall(ident("new"), resultId) + newAssignment(resultId, newCall(ident("new" & getMessageName(desc)))) ) let tagid = ident("tag") @@ -325,7 +347,11 @@ proc generateReadMessageProc(desc: NimNode): NimNode = let number = getFieldNumber(field) if isRepeated(field): let adder = ident("add" & capitalizeAscii(getFieldName(field))) - let reader = ident("read" & $getFieldType(field)) + let reader = + if getFieldType(field) == FieldType.Message: + ident("read" & getFieldTypeName(field)) + else: + ident("read" & $getFieldType(field)) if isNumeric(getFieldType(field)): add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: if `wiretypeId` == WireType.LengthDelimited: @@ -343,15 +369,36 @@ proc generateReadMessageProc(desc: NimNode): NimNode = `adder`(`resultId`, `reader`(stream)) )) else: - add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: - `adder`(`resultId`, `reader`(stream)) - )) + if getFieldType(field) == FieldType.Message: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `adder`(`resultId`, `reader`(stream2)) + )) + else: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + `adder`(`resultId`, `reader`(stream)) + )) else: let setter = ident("set" & capitalizeAscii(getFieldName(field))) - let reader = ident("read" & $getFieldType(field)) - add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: - `setter`(`resultId`, `reader`(stream)) - )) + let reader = + if getFieldType(field) == FieldType.Message: + ident("read" & getFieldTypeName(field)) + else: + ident("read" & $getFieldType(field)) + if getFieldType(field) == FieldType.Message: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `setter`(`resultId`, `reader`(stream2)) + )) + else: + add(caseNode, nnkOfBranch.newTree(newLit(number), quote do: + `setter`(`resultId`, `reader`(stream)) + )) + # TODO: generate code to skip unknown fields add(caseNode, nnkElse.newTree(quote do: @@ -362,6 +409,58 @@ proc generateReadMessageProc(desc: NimNode): NimNode = @[ident(name), newIdentDefs(ident("stream"), ident("ProtobufStream"))], body) +proc generateSizeOfMessageProc(desc: NimNode): NimNode = + let + name = getMessageName(desc) + body = newStmtList() + messageId = ident("message") + resultId = ident("result") + + for field in fields(desc): + let + hasproc = ident("has" & capitalizeAscii(getFieldName(field))) + sizeofproc = + if getFieldType(field) == FieldType.Message: + ident("sizeOf" & getFieldTypeName(field)) + else: + ident("sizeOf" & $getFieldType(field)) + fname = newDotExpr(messageId, ident(getFieldName(field))) + number = getFieldNumber(field) + wiretype = ident(wiretype(field)) + + # TODO: packed + if isRepeated(field): + body.add quote do: + if `hasproc`(`messageId`): + for value in `fname`: + let + sizeOfField = `sizeofproc`(value) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + + sizeOfField + + sizeOfUint64(sizeOfField) + + tagSize + else: + if getFieldType(field) == FieldType.Message: + body.add quote do: + if `hasproc`(`messageId`): + let + sizeOfField = `sizeofproc`(`fname`) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + sizeOfField + tagSize + + sizeOfUint64(sizeOfField) + else: + body.add quote do: + if `hasproc`(`messageId`): + let + sizeOfField = `sizeofproc`(`fname`) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + sizeOfField + tagSize + + result = newProc(postfix(ident("sizeOf" & name), "*"), + @[ident("uint64"), newIdentDefs(messageId, ident(name))], + body) + macro generateMessageProcs*(x: typed): typed = let desc = getImpl(symbol(x)) @@ -380,6 +479,7 @@ macro generateMessageProcs*(x: typed): typed = add(result, generateWriteMessageProc(desc)) add(result, generateReadMessageProc(desc)) + add(result, generateSizeOfMessageProc(desc)) when defined(debug): hint(repr(result)) diff --git a/tests/test3.nim b/tests/test3.nim new file mode 100644 index 0000000..9c79b84 --- /dev/null +++ b/tests/test3.nim @@ -0,0 +1,55 @@ +import intsets + +import protobuf/types +import protobuf/gen +import protobuf/stream + +const + Test1Desc = MessageDesc( + name: "Test1", + fields: @[ + FieldDesc( + name: "a", + number: 1, + ftype: FieldType.Int32, + label: FieldLabel.Required, + typeName: "", + packed: true + ) + ] + ) + + Test3Desc = MessageDesc( + name: "Test3", + fields: @[ + FieldDesc( + name: "c", + number: 3, + ftype: FieldType.Message, + label: FieldLabel.Required, + typeName: "Test1", + packed: false + ) + ] + ) + +generateMessageType(Test1Desc) +generateMessageProcs(Test1Desc) + +generateMessageType(Test3Desc) +generateMessageProcs(Test3Desc) + +import strutils +let message = newTest3() +let t1 = newTest1() +setA(t1, 150) +setC(message, t1) +let ss = newStringStream() +let pbs = newProtobufStream(ss) +writeTest3(pbs, message) +for b in ss.data: + echo(toHex(int(b), 2)) + +setPosition(pbs, 0) +let message2 = readTest3(pbs) +echo(message2.c.a) |
