diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2018-03-27 23:16:15 +0300 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2018-03-27 23:16:15 +0300 |
| commit | 5e9000fcdcdb86981d52841abf1c4d5385ae2ba8 (patch) | |
| tree | 7ca8af8c114d14986a05ce0cd045c71a4b400d6c /src | |
| parent | d76ec81388454c24ee99a601830ed39cfa50063c (diff) | |
| download | nimpb-5e9000fcdcdb86981d52841abf1c4d5385ae2ba8.tar.gz nimpb-5e9000fcdcdb86981d52841abf1c4d5385ae2ba8.zip | |
Initial support for oneofs
Diffstat (limited to 'src')
| -rw-r--r-- | src/protobuf/gen.nim | 128 |
1 files changed, 111 insertions, 17 deletions
diff --git a/src/protobuf/gen.nim b/src/protobuf/gen.nim index fd3b3eb..c21565e 100644 --- a/src/protobuf/gen.nim +++ b/src/protobuf/gen.nim @@ -7,6 +7,7 @@ type MessageDesc* = object name*: string fields*: seq[FieldDesc] + oneofs*: seq[string] FieldLabel* {.pure.} = enum Optional = 1 @@ -20,6 +21,7 @@ type label*: FieldLabel typeName*: string packed*: bool + oneofIdx*: int EnumDesc* = object name*: string @@ -137,21 +139,81 @@ proc defaultValue(field: NimNode): NimNode = proc wiretype(field: NimNode): WireType = result = wiretype(getFieldType(field)) -proc fieldInitializer(objname: string, field: NimNode): NimNode = +# TODO: maybe not the best name for this +proc getFieldNameAST(objname: NimNode, field: NimNode, oneof: string): NimNode = + result = + if oneof != "": + newDotExpr(newDotExpr(objname, ident(oneof)), ident(getFieldName(field))) + else: + newDotExpr(objname, ident(getFieldName(field))) + +proc fieldInitializer(objname: NimNode, field: NimNode, oneof: string): NimNode = result = nnkAsgn.newTree( - nnkDotExpr.newTree( - newIdentNode(objname), - newIdentNode(getFieldName(field)) - ), + getFieldNameAST(objname, field, oneof), defaultValue(field) ) +proc oneofIndex(field: NimNode): int = + let node = findColonExpr(field, "oneofIdx") + if node == nil: + result = -1 + else: + result = int(node[1].intVal) + +proc oneofName(message, field: NimNode): string = + let index = oneofIndex(field) + + if index == -1: + return "" + + let oneofs = findColonExpr(message, "oneofs")[1] + + result = $oneofs[index] + +iterator oneofFields(message: NimNode, index: int): NimNode = + if index != -1: + for field in fields(message): + if oneofIndex(field) == index: + yield field + +proc generateOneofFields*(desc: NimNode, typeSection: NimNode) = + let + oneofs = findColonExpr(desc, "oneofs")[1] + messageName = getMessageName(desc) + + for index, oneof in oneofs: + let reclist = nnkRecList.newTree() + + for field in oneofFields(desc, index): + let ftype = getFullFieldType(field) + let name = ident(getFieldName(field)) + + add(reclist, newIdentDefs(postfix(name, "*"), ftype)) + + let typedef = nnkTypeDef.newTree( + nnkPragmaExpr.newTree( + postfix(ident(messageName & $oneof), "*"), + nnkPragma.newTree( + ident("union") + ) + ), + newEmptyNode(), + nnkObjectTy.newTree( + newEmptyNode(), + newEmptyNode(), + reclist + ) + ) + + add(typeSection, typedef) + macro generateMessageType*(desc: typed): typed = let impl = getImpl(symbol(desc)) typeSection = nnkTypeSection.newTree() typedef = nnkTypeDef.newTree() reclist = nnkRecList.newTree() + oneofs = findColonExpr(impl, "oneofs")[1] let name = getMessageName(impl) @@ -168,11 +230,18 @@ macro generateMessageType*(desc: typed): typed = for field in fields(impl): let ftype = getFullFieldType(field) let name = ident(getFieldName(field)) - add(reclist, newIdentDefs(postfix(name, "*"), ftype)) + if oneofIndex(field) == -1: + add(reclist, newIdentDefs(postfix(name, "*"), ftype)) + + for oneof in oneofs: + add(reclist, newIdentDefs(postfix(ident($oneof), "*"), + ident(name & $oneof))) add(reclist, nnkIdentDefs.newTree( ident("hasField"), ident("IntSet"), newEmptyNode())) + generateOneofFields(impl, typeSection) + result = newStmtList() add(result, typeSection) @@ -180,14 +249,17 @@ macro generateMessageType*(desc: typed): typed = hint(repr(result)) proc generateNewMessageProc(desc: NimNode): NimNode = - let body = newStmtList( - newCall(ident("new"), ident("result")) - ) + let + body = newStmtList( + newCall(ident("new"), ident("result")) + ) + resultId = ident("result") for field in fields(desc): - add(body, fieldInitializer("result", field)) + let oneofName = oneofName(desc, field) + add(body, fieldInitializer(resultId, field, oneofName)) - add(body, newAssignment(newDotExpr(ident("result"), ident("hasField")), + add(body, newAssignment(newDotExpr(resultId, ident("hasField")), newCall(ident("initIntSet")))) result = newProc(postfix(ident("new" & getMessageName(desc)), "*"), @@ -203,7 +275,7 @@ proc fieldProcIdent(prefix: string, field: NimNode): NimNode = proc generateClearFieldProc(desc, field: NimNode): NimNode = let messageId = ident("message") - fname = newDotExpr(messageId, ident(getFieldName(field))) + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) defvalue = defaultValue(field) hasField = newDotExpr(messageId, ident("hasField")) number = getFieldNumber(field) @@ -215,6 +287,17 @@ proc generateClearFieldProc(desc, field: NimNode): NimNode = `fname` = `defvalue` excl(`hasfield`, `number`) + # When clearing a field that is contained in a oneof, we should also clear + # the other fields. + for sibling in oneofFields(desc, oneofIndex(field)): + if sibling == field: + continue + let + number = getFieldNumber(sibling) + exclNode = quote do: + excl(`hasField`, `number`) + add(body(result), exclNode) + proc generateHasFieldProc(desc, field: NimNode): NimNode = let messageId = ident("message") @@ -233,7 +316,7 @@ proc generateSetFieldProc(desc, field: NimNode): NimNode = hasField = newDotExpr(messageId, ident("hasField")) number = getFieldNumber(field) valueId = ident("value") - fname = newDotExpr(messageId, ident(getFieldName(field))) + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) procName = fieldProcIdent("set", field) mtype = ident(getMessageName(desc)) ftype = getFullFieldType(field) @@ -243,6 +326,16 @@ proc generateSetFieldProc(desc, field: NimNode): NimNode = `fname` = `valueId` incl(`hasfield`, `number`) + # When setting a field that is in a oneof, we need to unset the other fields + for sibling in oneofFields(desc, oneofIndex(field)): + if sibling == field: + continue + let + number = getFieldNumber(sibling) + exclNode = quote do: + excl(`hasField`, `number`) + add(body(result), exclNode) + proc generateAddToFieldProc(desc, field: NimNode): NimNode = let procName = fieldProcIdent("add", field) @@ -262,13 +355,14 @@ proc generateAddToFieldProc(desc, field: NimNode): NimNode = proc ident(wt: WireType): NimNode = result = newDotExpr(ident("WireType"), ident($wt)) -proc genWriteField(field: NimNode): NimNode = +proc genWriteField(message, field: NimNode): NimNode = result = newStmtList() let number = getFieldNumber(field) writer = ident("write" & getFieldTypeAsString(field)) - fname = newDotExpr(ident("message"), ident(getFieldName(field))) + messageId = ident("message") + fname = getFieldNameAST(messageId, field, oneofName(message, field)) wiretype = ident(wiretype(field)) sizeproc = ident("sizeOf" & getFieldTypeAsString(field)) hasproc = ident(fieldProcName("has", field)) @@ -310,7 +404,7 @@ proc generateWriteMessageProc(desc: NimNode): NimNode = sizeproc = postfix(ident("sizeOf" & getMessageName(desc)), "*") for field in fields(desc): - add(body, genWriteField(field)) + add(body, genWriteField(desc, field)) result = quote do: proc `sizeproc`(`messageId`: `mtype`): uint64 @@ -410,7 +504,7 @@ proc generateSizeOfMessageProc(desc: NimNode): NimNode = let hasproc = ident(fieldProcName("has", field)) sizeofproc = ident("sizeOf" & getFieldTypeAsString(field)) - fname = newDotExpr(messageId, ident(getFieldName(field))) + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) number = getFieldNumber(field) wiretype = ident(wiretype(field)) |
