aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOskari Timperi <oskari.timperi@iki.fi>2018-03-24 18:29:11 +0200
committerOskari Timperi <oskari.timperi@iki.fi>2018-03-24 18:29:11 +0200
commit5b25e192705d4040b51665e49627f8ea037186ee (patch)
tree1c4f0fed3098c28a2ea0200166bb35e09cb4438a
parent0bc7b67059868af65d2158a8aeade5b6f777431b (diff)
downloadnimpb-5b25e192705d4040b51665e49627f8ea037186ee.tar.gz
nimpb-5b25e192705d4040b51665e49627f8ea037186ee.zip
Initial support for embedded messages
-rw-r--r--src/protobuf/gen.nim130
-rw-r--r--tests/test3.nim55
2 files changed, 170 insertions, 15 deletions
diff --git a/src/protobuf/gen.nim b/src/protobuf/gen.nim
index 881b275..76d36fe 100644
--- a/src/protobuf/gen.nim
+++ b/src/protobuf/gen.nim
@@ -72,9 +72,16 @@ proc getFieldType(field: NimNode): FieldType =
let node = findColonExpr(field, "ftype")
result = FieldType(node[1].intVal)
+proc getFieldTypeName(field: NimNode): string =
+ let node = findColonExpr(field, "typeName")
+ result = $node[1]
+
proc getFullFieldType(field: NimNode): NimNode =
let ftype = getFieldType(field)
- result = toNimNode(ftype)
+ if ftype == FieldType.Message:
+ result = ident(getFieldTypeName(field))
+ else:
+ result = toNimNode(ftype)
if isRepeated(field):
result = nnkBracketExpr.newTree(ident("seq"), result)
@@ -102,7 +109,7 @@ proc defaultValue(field: NimNode): NimNode =
of FieldType.Bool: result = newLit(false)
of FieldType.String: result = newLit("")
of FieldType.Group: result = newLit("NOTIMPLEMENTED")
- of FieldType.Message: result = newLit("TODO")
+ of FieldType.Message: result = newCall(ident("new" & getFieldTypeName(field)))
of FieldType.Bytes: result = newCall(ident("bytes"), newLit(""))
of FieldType.UInt32: result = newLit(0'u32)
of FieldType.Enum: result = newLit("TODO")
@@ -242,7 +249,11 @@ proc generateAddToFieldProc(desc, field: NimNode): NimNode =
add(body, newCall("incl", newDotExpr(ident("message"), ident("hasField")),
newLit(getFieldNumber(field))))
- let ftype = toNimNode(getFieldType(field))
+ let ftype =
+ if getFieldType(field) == FieldType.Message:
+ ident(getFieldTypeName(field))
+ else:
+ toNimNode(getFieldType(field))
result = newProc(postfix(ident("add" & capitalizeAscii(fieldName)), "*"),
@[newEmptyNode(), newIdentDefs(ident("message"),
@@ -258,7 +269,11 @@ proc genWriteField(field: NimNode): NimNode =
let
number = getFieldNumber(field)
- writer = ident("write" & $getFieldType(field))
+ writer =
+ if getFieldType(field) == FieldType.Message:
+ ident("write" & getFieldTypeName(field))
+ else:
+ ident("write" & $getFieldType(field))
fname = newDotExpr(ident("message"), ident(getFieldName(field)))
wiretype = ident(wiretype(field))
@@ -266,6 +281,9 @@ proc genWriteField(field: NimNode): NimNode =
result.add quote do:
writeTag(stream, `number`, `wiretype`)
`writer`(stream, `fname`)
+ if getFieldType(field) == FieldType.Message:
+ insert(result[^1], 1, newCall(ident("writeVarint"), ident("stream"),
+ newCall(ident("sizeOf" & getFieldTypeName(field)), fname)))
else:
if isPacked(field):
result.add quote do:
@@ -274,10 +292,14 @@ proc genWriteField(field: NimNode): NimNode =
for value in `fname`:
`writer`(stream, value)
else:
+ let valueId = ident("value")
result.add quote do:
- for value in `fname`:
+ for `valueId` in `fname`:
writeTag(stream, `number`, `wiretype`)
- `writer`(stream, value)
+ `writer`(stream, `valueId`)
+ if getFieldType(field) == FieldType.Message:
+ insert(result[^1][^1], 1, newCall(ident("writeVarint"), ident("stream"),
+ newCall(ident("sizeOf" & getFieldTypeName(field)), valueId)))
proc generateWriteMessageProc(desc: NimNode): NimNode =
let body = newStmtList()
@@ -305,7 +327,7 @@ proc generateReadMessageProc(desc: NimNode): NimNode =
let resultId = ident("result")
let body = newStmtList(
- newCall(ident("new"), resultId)
+ newAssignment(resultId, newCall(ident("new" & getMessageName(desc))))
)
let tagid = ident("tag")
@@ -325,7 +347,11 @@ proc generateReadMessageProc(desc: NimNode): NimNode =
let number = getFieldNumber(field)
if isRepeated(field):
let adder = ident("add" & capitalizeAscii(getFieldName(field)))
- let reader = ident("read" & $getFieldType(field))
+ let reader =
+ if getFieldType(field) == FieldType.Message:
+ ident("read" & getFieldTypeName(field))
+ else:
+ ident("read" & $getFieldType(field))
if isNumeric(getFieldType(field)):
add(caseNode, nnkOfBranch.newTree(newLit(number), quote do:
if `wiretypeId` == WireType.LengthDelimited:
@@ -343,15 +369,36 @@ proc generateReadMessageProc(desc: NimNode): NimNode =
`adder`(`resultId`, `reader`(stream))
))
else:
- add(caseNode, nnkOfBranch.newTree(newLit(number), quote do:
- `adder`(`resultId`, `reader`(stream))
- ))
+ if getFieldType(field) == FieldType.Message:
+ add(caseNode, nnkOfBranch.newTree(newLit(number), quote do:
+ let size = readVarint(stream)
+ let data = readStr(stream, int(size))
+ let stream2 = newProtobufStream(newStringStream(data))
+ `adder`(`resultId`, `reader`(stream2))
+ ))
+ else:
+ add(caseNode, nnkOfBranch.newTree(newLit(number), quote do:
+ `adder`(`resultId`, `reader`(stream))
+ ))
else:
let setter = ident("set" & capitalizeAscii(getFieldName(field)))
- let reader = ident("read" & $getFieldType(field))
- add(caseNode, nnkOfBranch.newTree(newLit(number), quote do:
- `setter`(`resultId`, `reader`(stream))
- ))
+ let reader =
+ if getFieldType(field) == FieldType.Message:
+ ident("read" & getFieldTypeName(field))
+ else:
+ ident("read" & $getFieldType(field))
+ if getFieldType(field) == FieldType.Message:
+ add(caseNode, nnkOfBranch.newTree(newLit(number), quote do:
+ let size = readVarint(stream)
+ let data = readStr(stream, int(size))
+ let stream2 = newProtobufStream(newStringStream(data))
+ `setter`(`resultId`, `reader`(stream2))
+ ))
+ else:
+ add(caseNode, nnkOfBranch.newTree(newLit(number), quote do:
+ `setter`(`resultId`, `reader`(stream))
+ ))
+
# TODO: generate code to skip unknown fields
add(caseNode, nnkElse.newTree(quote do:
@@ -362,6 +409,58 @@ proc generateReadMessageProc(desc: NimNode): NimNode =
@[ident(name), newIdentDefs(ident("stream"), ident("ProtobufStream"))],
body)
+proc generateSizeOfMessageProc(desc: NimNode): NimNode =
+ let
+ name = getMessageName(desc)
+ body = newStmtList()
+ messageId = ident("message")
+ resultId = ident("result")
+
+ for field in fields(desc):
+ let
+ hasproc = ident("has" & capitalizeAscii(getFieldName(field)))
+ sizeofproc =
+ if getFieldType(field) == FieldType.Message:
+ ident("sizeOf" & getFieldTypeName(field))
+ else:
+ ident("sizeOf" & $getFieldType(field))
+ fname = newDotExpr(messageId, ident(getFieldName(field)))
+ number = getFieldNumber(field)
+ wiretype = ident(wiretype(field))
+
+ # TODO: packed
+ if isRepeated(field):
+ body.add quote do:
+ if `hasproc`(`messageId`):
+ for value in `fname`:
+ let
+ sizeOfField = `sizeofproc`(value)
+ tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`)))
+ `resultId` = `resultId` +
+ sizeOfField +
+ sizeOfUint64(sizeOfField) +
+ tagSize
+ else:
+ if getFieldType(field) == FieldType.Message:
+ body.add quote do:
+ if `hasproc`(`messageId`):
+ let
+ sizeOfField = `sizeofproc`(`fname`)
+ tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`)))
+ `resultId` = `resultId` + sizeOfField + tagSize +
+ sizeOfUint64(sizeOfField)
+ else:
+ body.add quote do:
+ if `hasproc`(`messageId`):
+ let
+ sizeOfField = `sizeofproc`(`fname`)
+ tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`)))
+ `resultId` = `resultId` + sizeOfField + tagSize
+
+ result = newProc(postfix(ident("sizeOf" & name), "*"),
+ @[ident("uint64"), newIdentDefs(messageId, ident(name))],
+ body)
+
macro generateMessageProcs*(x: typed): typed =
let
desc = getImpl(symbol(x))
@@ -380,6 +479,7 @@ macro generateMessageProcs*(x: typed): typed =
add(result, generateWriteMessageProc(desc))
add(result, generateReadMessageProc(desc))
+ add(result, generateSizeOfMessageProc(desc))
when defined(debug):
hint(repr(result))
diff --git a/tests/test3.nim b/tests/test3.nim
new file mode 100644
index 0000000..9c79b84
--- /dev/null
+++ b/tests/test3.nim
@@ -0,0 +1,55 @@
+import intsets
+
+import protobuf/types
+import protobuf/gen
+import protobuf/stream
+
+const
+ Test1Desc = MessageDesc(
+ name: "Test1",
+ fields: @[
+ FieldDesc(
+ name: "a",
+ number: 1,
+ ftype: FieldType.Int32,
+ label: FieldLabel.Required,
+ typeName: "",
+ packed: true
+ )
+ ]
+ )
+
+ Test3Desc = MessageDesc(
+ name: "Test3",
+ fields: @[
+ FieldDesc(
+ name: "c",
+ number: 3,
+ ftype: FieldType.Message,
+ label: FieldLabel.Required,
+ typeName: "Test1",
+ packed: false
+ )
+ ]
+ )
+
+generateMessageType(Test1Desc)
+generateMessageProcs(Test1Desc)
+
+generateMessageType(Test3Desc)
+generateMessageProcs(Test3Desc)
+
+import strutils
+let message = newTest3()
+let t1 = newTest1()
+setA(t1, 150)
+setC(message, t1)
+let ss = newStringStream()
+let pbs = newProtobufStream(ss)
+writeTest3(pbs, message)
+for b in ss.data:
+ echo(toHex(int(b), 2))
+
+setPosition(pbs, 0)
+let message2 = readTest3(pbs)
+echo(message2.c.a)