diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2018-04-03 21:14:59 +0300 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2018-04-03 21:14:59 +0300 |
| commit | 3b20e280d04708d505ab5770e8ea09aa6eab5aaf (patch) | |
| tree | f3100efa2856d45e4e5504ece9df12832a8dc2dc /generator | |
| parent | f27b133ad18d3884fb305083b08bceba18730fd7 (diff) | |
| download | nimpb-3b20e280d04708d505ab5770e8ea09aa6eab5aaf.tar.gz nimpb-3b20e280d04708d505ab5770e8ea09aa6eab5aaf.zip | |
Move src/protobuf/gen.nim to generator/gen.nim
Diffstat (limited to 'generator')
| -rw-r--r-- | generator/descriptor_pb.nim | 2 | ||||
| -rw-r--r-- | generator/gen.nim | 655 | ||||
| -rw-r--r-- | generator/plugin_pb.nim | 2 | ||||
| -rw-r--r-- | generator/protoc_gen_nim.nim | 3 |
4 files changed, 659 insertions, 3 deletions
diff --git a/generator/descriptor_pb.nim b/generator/descriptor_pb.nim index 433a720..8eeecb6 100644 --- a/generator/descriptor_pb.nim +++ b/generator/descriptor_pb.nim @@ -1,6 +1,6 @@ import intsets -import protobuf/gen +import gen import protobuf/protobuf const diff --git a/generator/gen.nim b/generator/gen.nim new file mode 100644 index 0000000..c51b746 --- /dev/null +++ b/generator/gen.nim @@ -0,0 +1,655 @@ +import macros +import strutils + +import protobuf/protobuf + +type + MessageDesc* = object + name*: string + fields*: seq[FieldDesc] + oneofs*: seq[string] + + FieldLabel* {.pure.} = enum + Optional = 1 + Required + Repeated + + FieldDesc* = object + name*: string + number*: int + ftype*: FieldType + label*: FieldLabel + typeName*: string + packed*: bool + oneofIdx*: int + + EnumDesc* = object + name*: string + values*: seq[EnumValueDesc] + + EnumValueDesc* = object + name*: string + number*: int + +proc findColonExpr(parent: NimNode, s: string): NimNode = + for child in parent: + if child.kind != nnkExprColonExpr: + continue + + if $child[0] == s: + return child + +proc getMessageName(desc: NimNode): string = + let node = findColonExpr(desc, "name") + result = $node[1] + +iterator fields(desc: NimNode): NimNode = + let node = findColonExpr(desc, "fields") + for field in node[1]: + yield field + +proc isRepeated(field: NimNode): bool = + let node = findColonExpr(field, "label") + let value = FieldLabel(node[1].intVal) + result = value == FieldLabel.Repeated + +proc isPacked(field: NimNode): bool = + let node = findColonExpr(field, "packed") + result = bool(node[1].intVal) + +proc getFieldType(field: NimNode): FieldType = + let node = findColonExpr(field, "ftype") + result = FieldType(node[1].intVal) + +proc isMessage(field: NimNode): bool = + result = getFieldType(field) == FieldType.Message + +proc isEnum(field: NimNode): bool = + result = getFieldType(field) == FieldType.Enum + +proc getFieldTypeName(field: NimNode): string = + let node = findColonExpr(field, "typeName") + result = $node[1] + +proc getFieldTypeAsString(field: NimNode): string = + if isMessage(field) or isEnum(field): + result = getFieldTypeName(field) + else: + case getFieldType(field) + of FieldType.Double: result = "float64" + of FieldType.Float: result = "float32" + of FieldType.Int64: result = "int64" + of FieldType.UInt64: result = "uint64" + of FieldType.Int32: result = "int32" + of FieldType.Fixed64: result = "uint64" + of FieldType.Fixed32: result = "uint32" + of FieldType.Bool: result = "bool" + of FieldType.String: result = "string" + of FieldType.Bytes: result = "bytes" + of FieldType.UInt32: result = "uint32" + of FieldType.SFixed32: result = "int32" + of FieldType.SFixed64: result = "int64" + of FieldType.SInt32: result = "int32" + of FieldType.SInt64: result = "int64" + else: result = "AYBABTU" + +proc getFullFieldType(field: NimNode): NimNode = + result = ident(getFieldTypeAsString(field)) + if isRepeated(field): + result = nnkBracketExpr.newTree(ident("seq"), result) + +proc getFieldName(field: NimNode): string = + let node = findColonExpr(field, "name") + result = $node[1] + +proc getFieldNumber(field: NimNode): int = + result = int(findColonExpr(field, "number")[1].intVal) + +proc defaultValue(field: NimNode): NimNode = + # TODO: check if there is a default value specified for the field + + if isRepeated(field): + return nnkPrefix.newTree(newIdentNode("@"), nnkBracket.newTree()) + + case getFieldType(field) + of FieldType.Double: result = newLit(0.0'f64) + of FieldType.Float: result = newLit(0.0'f32) + of FieldType.Int64: result = newLit(0'i64) + of FieldType.UInt64: result = newLit(0'u64) + of FieldType.Int32: result = newLit(0'i32) + of FieldType.Fixed64: result = newLit(0'u64) + of FieldType.Fixed32: result = newLit(0'u32) + of FieldType.Bool: result = newLit(false) + of FieldType.String: result = newLit("") + of FieldType.Group: result = newLit("NOTIMPLEMENTED") + of FieldType.Message: result = newCall(ident("new" & getFieldTypeAsString(field))) + of FieldType.Bytes: result = newCall(ident("bytes"), newLit("")) + of FieldType.UInt32: result = newLit(0'u32) + of FieldType.Enum: + let + descId = ident(getFieldTypeAsString(field) & "Desc") + nameId = ident(getFieldTypeAsString(field)) + result = quote do: + `nameId`(`descId`.values[0].number) + of FieldType.SFixed32: result = newLit(0'u32) + of FieldType.SFixed64: result = newLit(0'u32) + of FieldType.SInt32: result = newLit(0) + of FieldType.SInt64: result = newLit(0) + +proc wiretype(field: NimNode): WireType = + result = wiretype(getFieldType(field)) + +# TODO: maybe not the best name for this +proc getFieldNameAST(objname: NimNode, field: NimNode, oneof: string): NimNode = + result = + if oneof != "": + newDotExpr(newDotExpr(objname, ident(oneof)), ident(getFieldName(field))) + else: + newDotExpr(objname, ident(getFieldName(field))) + +proc fieldInitializer(objname: NimNode, field: NimNode, oneof: string): NimNode = + result = nnkAsgn.newTree( + getFieldNameAST(objname, field, oneof), + defaultValue(field) + ) + +proc oneofIndex(field: NimNode): int = + let node = findColonExpr(field, "oneofIdx") + if node == nil: + result = -1 + else: + result = int(node[1].intVal) + +proc oneofName(message, field: NimNode): string = + let index = oneofIndex(field) + + if index == -1: + return "" + + let oneofs = findColonExpr(message, "oneofs")[1] + + result = $oneofs[index] + +iterator oneofFields(message: NimNode, index: int): NimNode = + if index != -1: + for field in fields(message): + if oneofIndex(field) == index: + yield field + +proc generateOneofFields*(desc: NimNode, typeSection: NimNode) = + let + oneofs = findColonExpr(desc, "oneofs")[1] + messageName = getMessageName(desc) + + for index, oneof in oneofs: + let reclist = nnkRecList.newTree() + + for field in oneofFields(desc, index): + let ftype = getFullFieldType(field) + let name = ident(getFieldName(field)) + + add(reclist, newIdentDefs(postfix(name, "*"), ftype)) + + let typedef = nnkTypeDef.newTree( + nnkPragmaExpr.newTree( + postfix(ident(messageName & $oneof), "*"), + nnkPragma.newTree( + ident("union") + ) + ), + newEmptyNode(), + nnkObjectTy.newTree( + newEmptyNode(), + newEmptyNode(), + reclist + ) + ) + + add(typeSection, typedef) + +macro generateMessageType*(desc: typed): typed = + let + impl = getImpl(symbol(desc)) + typeSection = nnkTypeSection.newTree() + typedef = nnkTypeDef.newTree() + reclist = nnkRecList.newTree() + oneofs = findColonExpr(impl, "oneofs")[1] + + let name = getMessageName(impl) + + let typedefRef = nnkTypeDef.newTree(postfix(newIdentNode(name), "*"), newEmptyNode(), + nnkRefTy.newTree(newIdentNode(name & "Obj"))) + add(typeSection, typedefRef) + + add(typeSection, typedef) + + add(typedef, postfix(ident(name & "Obj"), "*")) + add(typedef, newEmptyNode()) + add(typedef, nnkObjectTy.newTree(newEmptyNode(), newEmptyNode(), reclist)) + + for field in fields(impl): + let ftype = getFullFieldType(field) + let name = ident(getFieldName(field)) + if oneofIndex(field) == -1: + add(reclist, newIdentDefs(postfix(name, "*"), ftype)) + + for oneof in oneofs: + add(reclist, newIdentDefs(postfix(ident($oneof), "*"), + ident(name & $oneof))) + + add(reclist, nnkIdentDefs.newTree( + ident("hasField"), ident("IntSet"), newEmptyNode())) + + generateOneofFields(impl, typeSection) + + result = newStmtList() + add(result, typeSection) + + when defined(debug): + hint(repr(result)) + +proc generateNewMessageProc(desc: NimNode): NimNode = + let + body = newStmtList( + newCall(ident("new"), ident("result")) + ) + resultId = ident("result") + + for field in fields(desc): + let oneofName = oneofName(desc, field) + add(body, fieldInitializer(resultId, field, oneofName)) + + add(body, newAssignment(newDotExpr(resultId, ident("hasField")), + newCall(ident("initIntSet")))) + + result = newProc(postfix(ident("new" & getMessageName(desc)), "*"), + @[ident(getMessageName(desc))], + body) + +proc fieldProcName(prefix: string, field: NimNode): string = + result = prefix & capitalizeAscii(getFieldName(field)) + +proc fieldProcIdent(prefix: string, field: NimNode): NimNode = + result = postfix(ident(fieldProcName(prefix, field)), "*") + +proc generateClearFieldProc(desc, field: NimNode): NimNode = + let + messageId = ident("message") + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) + defvalue = defaultValue(field) + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + procName = fieldProcIdent("clear", field) + mtype = ident(getMessageName(desc)) + + result = quote do: + proc `procName`(`messageId`: `mtype`) = + `fname` = `defvalue` + excl(`hasfield`, `number`) + + # When clearing a field that is contained in a oneof, we should also clear + # the other fields. + for sibling in oneofFields(desc, oneofIndex(field)): + if sibling == field: + continue + let + number = getFieldNumber(sibling) + exclNode = quote do: + excl(`hasField`, `number`) + add(body(result), exclNode) + +proc generateHasFieldProc(desc, field: NimNode): NimNode = + let + messageId = ident("message") + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + mtype = ident(getMessageName(desc)) + procName = fieldProcIdent("has", field) + + result = quote do: + proc `procName`(`messageId`: `mtype`): bool = + contains(`hasfield`, `number`) + +proc generateSetFieldProc(desc, field: NimNode): NimNode = + let + messageId = ident("message") + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + valueId = ident("value") + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) + procName = fieldProcIdent("set", field) + mtype = ident(getMessageName(desc)) + ftype = getFullFieldType(field) + + result = quote do: + proc `procName`(`messageId`: `mtype`, `valueId`: `ftype`) = + `fname` = `valueId` + incl(`hasfield`, `number`) + + # When setting a field that is in a oneof, we need to unset the other fields + for sibling in oneofFields(desc, oneofIndex(field)): + if sibling == field: + continue + let + number = getFieldNumber(sibling) + exclNode = quote do: + excl(`hasField`, `number`) + add(body(result), exclNode) + +proc generateAddToFieldProc(desc, field: NimNode): NimNode = + let + procName = fieldProcIdent("add", field) + messageId = ident("message") + mtype = ident(getMessageName(desc)) + valueId = ident("value") + ftype = ident(getFieldTypeAsString(field)) + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + fname = newDotExpr(messageId, ident(getFieldName(field))) + + result = quote do: + proc `procName`(`messageId`: `mtype`, `valueId`: `ftype`) = + add(`fname`, `valueId`) + incl(`hasfield`, `number`) + +proc ident(wt: WireType): NimNode = + result = newDotExpr(ident("WireType"), ident($wt)) + +proc genWriteField(message, field: NimNode): NimNode = + result = newStmtList() + + let + number = getFieldNumber(field) + writer = ident("write" & getFieldTypeAsString(field)) + messageId = ident("message") + fname = getFieldNameAST(messageId, field, oneofName(message, field)) + wiretype = ident(wiretype(field)) + sizeproc = ident("sizeOf" & getFieldTypeAsString(field)) + hasproc = ident(fieldProcName("has", field)) + + if not isRepeated(field): + result.add quote do: + if `hasproc`(message): + writeTag(stream, `number`, `wiretype`) + `writer`(stream, `fname`) + if isMessage(field): + insert(result[0][0][1], 1, quote do: + writeVarint(stream, `sizeproc`(`fname`)) + ) + else: + let valueId = ident("value") + if isPacked(field): + result.add quote do: + writeTag(stream, `number`, WireType.LengthDelimited) + writeVarInt(stream, packedFieldSize(`fname`, `wiretype`)) + for `valueId` in `fname`: + `writer`(stream, `valueId`) + else: + result.add quote do: + for `valueId` in `fname`: + writeTag(stream, `number`, `wiretype`) + `writer`(stream, `valueId`) + if isMessage(field): + insert(result[^1][^1], 1, quote do: + writeVarint(stream, `sizeproc`(`valueId`)) + ) + +proc generateWriteMessageProc(desc: NimNode): NimNode = + let + messageId = ident("message") + mtype = ident(getMessageName(desc)) + procName = postfix(ident("write" & getMessageName(desc)), "*") + body = newStmtList() + stream = ident("stream") + sizeproc = postfix(ident("sizeOf" & getMessageName(desc)), "*") + + for field in fields(desc): + add(body, genWriteField(desc, field)) + + result = quote do: + proc `sizeproc`(`messageId`: `mtype`): uint64 + + proc `procName`(`stream`: ProtobufStream, `messageId`: `mtype`) = + `body` + +proc generateReadMessageProc(desc: NimNode): NimNode = + let + procName = postfix(ident("read" & getMessageName(desc)), "*") + newproc = ident("new" & getMessageName(desc)) + streamId = ident("stream") + mtype = ident(getMessageName(desc)) + tagId = ident("tag") + wiretypeId = ident("wiretype") + resultId = ident("result") + + result = quote do: + proc `procName`(`streamId`: ProtobufStream): `mtype` = + `resultId` = `newproc`() + while not atEnd(stream): + let + `tagId` = readTag(`streamId`) + `wiretypeId` = getTagWireType(`tagId`) + case getTagFieldNumber(`tagId`) + else: + skipField(`streamId`, `wiretypeId`) + + let caseNode = body(result)[1][1][1] + + # TODO: check wiretypes and fail if it doesn't match + for field in fields(desc): + let + number = getFieldNumber(field) + reader = ident("read" & getFieldTypeAsString(field)) + setproc = + if isRepeated(field): + ident("add" & capitalizeAscii(getFieldName(field))) + else: + ident("set" & capitalizeAscii(getFieldName(field))) + if isRepeated(field): + if isNumeric(getFieldType(field)): + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + if `wiretypeId` == WireType.LengthDelimited: + let + size = readVarint(stream) + start = getPosition(stream).uint64 + var consumed = 0'u64 + while consumed < size: + `setproc`(`resultId`, `reader`(stream)) + consumed = getPosition(stream).uint64 - start + if consumed != size: + raise newException(Exception, "packed field size mismatch") + else: + `setproc`(`resultId`, `reader`(stream)) + )) + elif isMessage(field): + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `setproc`(`resultId`, `reader`(stream2)) + )) + else: + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + `setproc`(`resultId`, `reader`(stream)) + )) + else: + if isMessage(field): + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `setproc`(`resultId`, `reader`(stream2)) + )) + else: + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + `setproc`(`resultId`, `reader`(stream)) + )) + +proc generateSizeOfMessageProc(desc: NimNode): NimNode = + let + name = getMessageName(desc) + body = newStmtList() + messageId = ident("message") + resultId = ident("result") + procName = postfix(ident("sizeOf" & getMessageName(desc)), "*") + mtype = ident(getMessageName(desc)) + + result = quote do: + proc `procName`(`messageId`: `mtype`): uint64 = + `resultId` = 0 + + let procBody = body(result) + + for field in fields(desc): + let + hasproc = ident(fieldProcName("has", field)) + sizeofproc = ident("sizeOf" & getFieldTypeAsString(field)) + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) + number = getFieldNumber(field) + wiretype = ident(wiretype(field)) + + # TODO: packed + if isRepeated(field): + if isPacked(field): + procBody.add quote do: + if `hasproc`(`messageId`): + let + tagSize = sizeOfUint32(uint32(makeTag(`number`, WireType.LengthDelimited))) + dataSize = packedFieldSize(`fname`, `wiretype`) + sizeOfSize = sizeOfUint64(dataSize) + `resultId` = tagSize + dataSize + sizeOfSize + else: + procBody.add quote do: + for value in `fname`: + let + sizeOfField = `sizeofproc`(value) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + + sizeOfField + + sizeOfUint64(sizeOfField) + + tagSize + else: + let sizeOfFieldId = ident("sizeOfField") + + procBody.add quote do: + if `hasproc`(`messageId`): + let + `sizeOfFieldId` = `sizeofproc`(`fname`) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + sizeOfField + tagSize + + if isMessage(field): + # For messages we need to include the size of the encoded size + let asgn = procBody[^1][0][1][1] + asgn[1] = infix(asgn[1], "+", newCall(ident("sizeOfUint64"), + sizeOfFieldId)) + +proc generateSerializeProc(desc: NimNode): NimNode = + let + mtype = ident(getMessageName(desc)) + procName = postfix(ident("serialize"), "*") + writer = ident("write" & getMessageName(desc)) + resultId = ident("result") + + result = quote do: + proc `procName`(message: `mtype`): string = + let + ss = newStringStream() + pbs = newProtobufStream(ss) + `writer`(pbs, message) + `resultId` = ss.data + +proc generateDeserializeProc(desc: NimNode): NimNode = + let + mtype = ident(getMessageName(desc)) + procName = postfix(ident("new" & getMessageName(desc)), "*") + reader = ident("read" & getMessageName(desc)) + resultId = ident("result") + + result = quote do: + proc `procName`(data: string): `mtype` = + let + ss = newStringStream(data) + pbs = newProtobufStream(ss) + `resultId` = `reader`(pbs) + +macro generateMessageProcs*(x: typed): typed = + let + desc = getImpl(symbol(x)) + + result = newStmtList( + generateNewMessageProc(desc), + ) + + for field in fields(desc): + add(result, generateClearFieldProc(desc, field)) + add(result, generateHasFieldProc(desc, field)) + add(result, generateSetFieldProc(desc, field)) + + if isRepeated(field): + add(result, generateAddToFieldProc(desc, field)) + + add(result, generateWriteMessageProc(desc)) + add(result, generateReadMessageProc(desc)) + add(result, generateSizeOfMessageProc(desc)) + add(result, generateSerializeProc(desc)) + add(result, generateDeserializeProc(desc)) + + when defined(debug): + hint(repr(result)) + +macro generateEnumType*(x: typed): typed = + let + impl = getImpl(symbol(x)) + name = $findColonExpr(impl, "name")[1] + values = findColonExpr(impl, "values")[1] + + let enumTy = nnkEnumTy.newTree(newEmptyNode()) + + for valueNode in values: + let + name = $findColonExpr(valueNode, "name")[1] + number = findColonExpr(valueNode, "number")[1] + + add(enumTy, nnkEnumFieldDef.newTree(ident(name), number)) + + result = newStmtList(nnkTypeSection.newTree( + nnkTypeDef.newTree( + nnkPragmaExpr.newTree( + postfix(ident(name), "*"), + nnkPragma.newTree(ident("pure")) + ), + newEmptyNode(), + enumTy + ) + )) + + when defined(debug): + hint(repr(result)) + +macro generateEnumProcs*(x: typed): typed = + let + impl = getImpl(symbol(x)) + name = $findColonExpr(impl, "name")[1] + nameId = ident(name) + values = findColonExpr(impl, "values")[1] + readProc = postfix(ident("read" & name), "*") + writeProc = postfix(ident("write" & name), "*") + sizeProc = postfix(ident("sizeOf" & name), "*") + resultId = ident("result") + + result = newStmtList() + + add(result, quote do: + proc `readProc`(stream: ProtobufStream): `nameId` = + `resultId` = `nameId`(readUInt32(stream)) + + proc `writeProc`(stream: ProtobufStream, value: `nameId`) = + writeEnum(stream, value) + + proc `sizeProc`(value: `nameId`): uint64 = + `resultId` = sizeOfUInt32(uint32(value)) + ) + + when defined(debug): + hint(repr(result)) diff --git a/generator/plugin_pb.nim b/generator/plugin_pb.nim index 2f2eb1b..7ddd381 100644 --- a/generator/plugin_pb.nim +++ b/generator/plugin_pb.nim @@ -1,6 +1,6 @@ import intsets -import protobuf/gen +import gen import protobuf/protobuf import descriptor_pb diff --git a/generator/protoc_gen_nim.nim b/generator/protoc_gen_nim.nim index 3680de7..71b22f2 100644 --- a/generator/protoc_gen_nim.nim +++ b/generator/protoc_gen_nim.nim @@ -11,7 +11,8 @@ import descriptor_pb import plugin_pb import protobuf/protobuf -import protobuf/gen + +import gen type Names = distinct seq[string] |
