aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/protobuf/gen.nim128
1 files changed, 111 insertions, 17 deletions
diff --git a/src/protobuf/gen.nim b/src/protobuf/gen.nim
index fd3b3eb..c21565e 100644
--- a/src/protobuf/gen.nim
+++ b/src/protobuf/gen.nim
@@ -7,6 +7,7 @@ type
MessageDesc* = object
name*: string
fields*: seq[FieldDesc]
+ oneofs*: seq[string]
FieldLabel* {.pure.} = enum
Optional = 1
@@ -20,6 +21,7 @@ type
label*: FieldLabel
typeName*: string
packed*: bool
+ oneofIdx*: int
EnumDesc* = object
name*: string
@@ -137,21 +139,81 @@ proc defaultValue(field: NimNode): NimNode =
proc wiretype(field: NimNode): WireType =
result = wiretype(getFieldType(field))
-proc fieldInitializer(objname: string, field: NimNode): NimNode =
+# TODO: maybe not the best name for this
+proc getFieldNameAST(objname: NimNode, field: NimNode, oneof: string): NimNode =
+ result =
+ if oneof != "":
+ newDotExpr(newDotExpr(objname, ident(oneof)), ident(getFieldName(field)))
+ else:
+ newDotExpr(objname, ident(getFieldName(field)))
+
+proc fieldInitializer(objname: NimNode, field: NimNode, oneof: string): NimNode =
result = nnkAsgn.newTree(
- nnkDotExpr.newTree(
- newIdentNode(objname),
- newIdentNode(getFieldName(field))
- ),
+ getFieldNameAST(objname, field, oneof),
defaultValue(field)
)
+proc oneofIndex(field: NimNode): int =
+ let node = findColonExpr(field, "oneofIdx")
+ if node == nil:
+ result = -1
+ else:
+ result = int(node[1].intVal)
+
+proc oneofName(message, field: NimNode): string =
+ let index = oneofIndex(field)
+
+ if index == -1:
+ return ""
+
+ let oneofs = findColonExpr(message, "oneofs")[1]
+
+ result = $oneofs[index]
+
+iterator oneofFields(message: NimNode, index: int): NimNode =
+ if index != -1:
+ for field in fields(message):
+ if oneofIndex(field) == index:
+ yield field
+
+proc generateOneofFields*(desc: NimNode, typeSection: NimNode) =
+ let
+ oneofs = findColonExpr(desc, "oneofs")[1]
+ messageName = getMessageName(desc)
+
+ for index, oneof in oneofs:
+ let reclist = nnkRecList.newTree()
+
+ for field in oneofFields(desc, index):
+ let ftype = getFullFieldType(field)
+ let name = ident(getFieldName(field))
+
+ add(reclist, newIdentDefs(postfix(name, "*"), ftype))
+
+ let typedef = nnkTypeDef.newTree(
+ nnkPragmaExpr.newTree(
+ postfix(ident(messageName & $oneof), "*"),
+ nnkPragma.newTree(
+ ident("union")
+ )
+ ),
+ newEmptyNode(),
+ nnkObjectTy.newTree(
+ newEmptyNode(),
+ newEmptyNode(),
+ reclist
+ )
+ )
+
+ add(typeSection, typedef)
+
macro generateMessageType*(desc: typed): typed =
let
impl = getImpl(symbol(desc))
typeSection = nnkTypeSection.newTree()
typedef = nnkTypeDef.newTree()
reclist = nnkRecList.newTree()
+ oneofs = findColonExpr(impl, "oneofs")[1]
let name = getMessageName(impl)
@@ -168,11 +230,18 @@ macro generateMessageType*(desc: typed): typed =
for field in fields(impl):
let ftype = getFullFieldType(field)
let name = ident(getFieldName(field))
- add(reclist, newIdentDefs(postfix(name, "*"), ftype))
+ if oneofIndex(field) == -1:
+ add(reclist, newIdentDefs(postfix(name, "*"), ftype))
+
+ for oneof in oneofs:
+ add(reclist, newIdentDefs(postfix(ident($oneof), "*"),
+ ident(name & $oneof)))
add(reclist, nnkIdentDefs.newTree(
ident("hasField"), ident("IntSet"), newEmptyNode()))
+ generateOneofFields(impl, typeSection)
+
result = newStmtList()
add(result, typeSection)
@@ -180,14 +249,17 @@ macro generateMessageType*(desc: typed): typed =
hint(repr(result))
proc generateNewMessageProc(desc: NimNode): NimNode =
- let body = newStmtList(
- newCall(ident("new"), ident("result"))
- )
+ let
+ body = newStmtList(
+ newCall(ident("new"), ident("result"))
+ )
+ resultId = ident("result")
for field in fields(desc):
- add(body, fieldInitializer("result", field))
+ let oneofName = oneofName(desc, field)
+ add(body, fieldInitializer(resultId, field, oneofName))
- add(body, newAssignment(newDotExpr(ident("result"), ident("hasField")),
+ add(body, newAssignment(newDotExpr(resultId, ident("hasField")),
newCall(ident("initIntSet"))))
result = newProc(postfix(ident("new" & getMessageName(desc)), "*"),
@@ -203,7 +275,7 @@ proc fieldProcIdent(prefix: string, field: NimNode): NimNode =
proc generateClearFieldProc(desc, field: NimNode): NimNode =
let
messageId = ident("message")
- fname = newDotExpr(messageId, ident(getFieldName(field)))
+ fname = getFieldNameAST(messageId, field, oneofName(desc, field))
defvalue = defaultValue(field)
hasField = newDotExpr(messageId, ident("hasField"))
number = getFieldNumber(field)
@@ -215,6 +287,17 @@ proc generateClearFieldProc(desc, field: NimNode): NimNode =
`fname` = `defvalue`
excl(`hasfield`, `number`)
+ # When clearing a field that is contained in a oneof, we should also clear
+ # the other fields.
+ for sibling in oneofFields(desc, oneofIndex(field)):
+ if sibling == field:
+ continue
+ let
+ number = getFieldNumber(sibling)
+ exclNode = quote do:
+ excl(`hasField`, `number`)
+ add(body(result), exclNode)
+
proc generateHasFieldProc(desc, field: NimNode): NimNode =
let
messageId = ident("message")
@@ -233,7 +316,7 @@ proc generateSetFieldProc(desc, field: NimNode): NimNode =
hasField = newDotExpr(messageId, ident("hasField"))
number = getFieldNumber(field)
valueId = ident("value")
- fname = newDotExpr(messageId, ident(getFieldName(field)))
+ fname = getFieldNameAST(messageId, field, oneofName(desc, field))
procName = fieldProcIdent("set", field)
mtype = ident(getMessageName(desc))
ftype = getFullFieldType(field)
@@ -243,6 +326,16 @@ proc generateSetFieldProc(desc, field: NimNode): NimNode =
`fname` = `valueId`
incl(`hasfield`, `number`)
+ # When setting a field that is in a oneof, we need to unset the other fields
+ for sibling in oneofFields(desc, oneofIndex(field)):
+ if sibling == field:
+ continue
+ let
+ number = getFieldNumber(sibling)
+ exclNode = quote do:
+ excl(`hasField`, `number`)
+ add(body(result), exclNode)
+
proc generateAddToFieldProc(desc, field: NimNode): NimNode =
let
procName = fieldProcIdent("add", field)
@@ -262,13 +355,14 @@ proc generateAddToFieldProc(desc, field: NimNode): NimNode =
proc ident(wt: WireType): NimNode =
result = newDotExpr(ident("WireType"), ident($wt))
-proc genWriteField(field: NimNode): NimNode =
+proc genWriteField(message, field: NimNode): NimNode =
result = newStmtList()
let
number = getFieldNumber(field)
writer = ident("write" & getFieldTypeAsString(field))
- fname = newDotExpr(ident("message"), ident(getFieldName(field)))
+ messageId = ident("message")
+ fname = getFieldNameAST(messageId, field, oneofName(message, field))
wiretype = ident(wiretype(field))
sizeproc = ident("sizeOf" & getFieldTypeAsString(field))
hasproc = ident(fieldProcName("has", field))
@@ -310,7 +404,7 @@ proc generateWriteMessageProc(desc: NimNode): NimNode =
sizeproc = postfix(ident("sizeOf" & getMessageName(desc)), "*")
for field in fields(desc):
- add(body, genWriteField(field))
+ add(body, genWriteField(desc, field))
result = quote do:
proc `sizeproc`(`messageId`: `mtype`): uint64
@@ -410,7 +504,7 @@ proc generateSizeOfMessageProc(desc: NimNode): NimNode =
let
hasproc = ident(fieldProcName("has", field))
sizeofproc = ident("sizeOf" & getFieldTypeAsString(field))
- fname = newDotExpr(messageId, ident(getFieldName(field)))
+ fname = getFieldNameAST(messageId, field, oneofName(desc, field))
number = getFieldNumber(field)
wiretype = ident(wiretype(field))