From 3b20e280d04708d505ab5770e8ea09aa6eab5aaf Mon Sep 17 00:00:00 2001 From: Oskari Timperi Date: Tue, 3 Apr 2018 21:14:59 +0300 Subject: Move src/protobuf/gen.nim to generator/gen.nim --- generator/descriptor_pb.nim | 2 +- generator/gen.nim | 655 +++++++++++++++++++++++++++++++++++++++++++ generator/plugin_pb.nim | 2 +- generator/protoc_gen_nim.nim | 3 +- src/protobuf/gen.nim | 655 ------------------------------------------- 5 files changed, 659 insertions(+), 658 deletions(-) create mode 100644 generator/gen.nim delete mode 100644 src/protobuf/gen.nim diff --git a/generator/descriptor_pb.nim b/generator/descriptor_pb.nim index 433a720..8eeecb6 100644 --- a/generator/descriptor_pb.nim +++ b/generator/descriptor_pb.nim @@ -1,6 +1,6 @@ import intsets -import protobuf/gen +import gen import protobuf/protobuf const diff --git a/generator/gen.nim b/generator/gen.nim new file mode 100644 index 0000000..c51b746 --- /dev/null +++ b/generator/gen.nim @@ -0,0 +1,655 @@ +import macros +import strutils + +import protobuf/protobuf + +type + MessageDesc* = object + name*: string + fields*: seq[FieldDesc] + oneofs*: seq[string] + + FieldLabel* {.pure.} = enum + Optional = 1 + Required + Repeated + + FieldDesc* = object + name*: string + number*: int + ftype*: FieldType + label*: FieldLabel + typeName*: string + packed*: bool + oneofIdx*: int + + EnumDesc* = object + name*: string + values*: seq[EnumValueDesc] + + EnumValueDesc* = object + name*: string + number*: int + +proc findColonExpr(parent: NimNode, s: string): NimNode = + for child in parent: + if child.kind != nnkExprColonExpr: + continue + + if $child[0] == s: + return child + +proc getMessageName(desc: NimNode): string = + let node = findColonExpr(desc, "name") + result = $node[1] + +iterator fields(desc: NimNode): NimNode = + let node = findColonExpr(desc, "fields") + for field in node[1]: + yield field + +proc isRepeated(field: NimNode): bool = + let node = findColonExpr(field, "label") + let value = FieldLabel(node[1].intVal) + result = value == FieldLabel.Repeated + +proc isPacked(field: NimNode): bool = + let node = findColonExpr(field, "packed") + result = bool(node[1].intVal) + +proc getFieldType(field: NimNode): FieldType = + let node = findColonExpr(field, "ftype") + result = FieldType(node[1].intVal) + +proc isMessage(field: NimNode): bool = + result = getFieldType(field) == FieldType.Message + +proc isEnum(field: NimNode): bool = + result = getFieldType(field) == FieldType.Enum + +proc getFieldTypeName(field: NimNode): string = + let node = findColonExpr(field, "typeName") + result = $node[1] + +proc getFieldTypeAsString(field: NimNode): string = + if isMessage(field) or isEnum(field): + result = getFieldTypeName(field) + else: + case getFieldType(field) + of FieldType.Double: result = "float64" + of FieldType.Float: result = "float32" + of FieldType.Int64: result = "int64" + of FieldType.UInt64: result = "uint64" + of FieldType.Int32: result = "int32" + of FieldType.Fixed64: result = "uint64" + of FieldType.Fixed32: result = "uint32" + of FieldType.Bool: result = "bool" + of FieldType.String: result = "string" + of FieldType.Bytes: result = "bytes" + of FieldType.UInt32: result = "uint32" + of FieldType.SFixed32: result = "int32" + of FieldType.SFixed64: result = "int64" + of FieldType.SInt32: result = "int32" + of FieldType.SInt64: result = "int64" + else: result = "AYBABTU" + +proc getFullFieldType(field: NimNode): NimNode = + result = ident(getFieldTypeAsString(field)) + if isRepeated(field): + result = nnkBracketExpr.newTree(ident("seq"), result) + +proc getFieldName(field: NimNode): string = + let node = findColonExpr(field, "name") + result = $node[1] + +proc getFieldNumber(field: NimNode): int = + result = int(findColonExpr(field, "number")[1].intVal) + +proc defaultValue(field: NimNode): NimNode = + # TODO: check if there is a default value specified for the field + + if isRepeated(field): + return nnkPrefix.newTree(newIdentNode("@"), nnkBracket.newTree()) + + case getFieldType(field) + of FieldType.Double: result = newLit(0.0'f64) + of FieldType.Float: result = newLit(0.0'f32) + of FieldType.Int64: result = newLit(0'i64) + of FieldType.UInt64: result = newLit(0'u64) + of FieldType.Int32: result = newLit(0'i32) + of FieldType.Fixed64: result = newLit(0'u64) + of FieldType.Fixed32: result = newLit(0'u32) + of FieldType.Bool: result = newLit(false) + of FieldType.String: result = newLit("") + of FieldType.Group: result = newLit("NOTIMPLEMENTED") + of FieldType.Message: result = newCall(ident("new" & getFieldTypeAsString(field))) + of FieldType.Bytes: result = newCall(ident("bytes"), newLit("")) + of FieldType.UInt32: result = newLit(0'u32) + of FieldType.Enum: + let + descId = ident(getFieldTypeAsString(field) & "Desc") + nameId = ident(getFieldTypeAsString(field)) + result = quote do: + `nameId`(`descId`.values[0].number) + of FieldType.SFixed32: result = newLit(0'u32) + of FieldType.SFixed64: result = newLit(0'u32) + of FieldType.SInt32: result = newLit(0) + of FieldType.SInt64: result = newLit(0) + +proc wiretype(field: NimNode): WireType = + result = wiretype(getFieldType(field)) + +# TODO: maybe not the best name for this +proc getFieldNameAST(objname: NimNode, field: NimNode, oneof: string): NimNode = + result = + if oneof != "": + newDotExpr(newDotExpr(objname, ident(oneof)), ident(getFieldName(field))) + else: + newDotExpr(objname, ident(getFieldName(field))) + +proc fieldInitializer(objname: NimNode, field: NimNode, oneof: string): NimNode = + result = nnkAsgn.newTree( + getFieldNameAST(objname, field, oneof), + defaultValue(field) + ) + +proc oneofIndex(field: NimNode): int = + let node = findColonExpr(field, "oneofIdx") + if node == nil: + result = -1 + else: + result = int(node[1].intVal) + +proc oneofName(message, field: NimNode): string = + let index = oneofIndex(field) + + if index == -1: + return "" + + let oneofs = findColonExpr(message, "oneofs")[1] + + result = $oneofs[index] + +iterator oneofFields(message: NimNode, index: int): NimNode = + if index != -1: + for field in fields(message): + if oneofIndex(field) == index: + yield field + +proc generateOneofFields*(desc: NimNode, typeSection: NimNode) = + let + oneofs = findColonExpr(desc, "oneofs")[1] + messageName = getMessageName(desc) + + for index, oneof in oneofs: + let reclist = nnkRecList.newTree() + + for field in oneofFields(desc, index): + let ftype = getFullFieldType(field) + let name = ident(getFieldName(field)) + + add(reclist, newIdentDefs(postfix(name, "*"), ftype)) + + let typedef = nnkTypeDef.newTree( + nnkPragmaExpr.newTree( + postfix(ident(messageName & $oneof), "*"), + nnkPragma.newTree( + ident("union") + ) + ), + newEmptyNode(), + nnkObjectTy.newTree( + newEmptyNode(), + newEmptyNode(), + reclist + ) + ) + + add(typeSection, typedef) + +macro generateMessageType*(desc: typed): typed = + let + impl = getImpl(symbol(desc)) + typeSection = nnkTypeSection.newTree() + typedef = nnkTypeDef.newTree() + reclist = nnkRecList.newTree() + oneofs = findColonExpr(impl, "oneofs")[1] + + let name = getMessageName(impl) + + let typedefRef = nnkTypeDef.newTree(postfix(newIdentNode(name), "*"), newEmptyNode(), + nnkRefTy.newTree(newIdentNode(name & "Obj"))) + add(typeSection, typedefRef) + + add(typeSection, typedef) + + add(typedef, postfix(ident(name & "Obj"), "*")) + add(typedef, newEmptyNode()) + add(typedef, nnkObjectTy.newTree(newEmptyNode(), newEmptyNode(), reclist)) + + for field in fields(impl): + let ftype = getFullFieldType(field) + let name = ident(getFieldName(field)) + if oneofIndex(field) == -1: + add(reclist, newIdentDefs(postfix(name, "*"), ftype)) + + for oneof in oneofs: + add(reclist, newIdentDefs(postfix(ident($oneof), "*"), + ident(name & $oneof))) + + add(reclist, nnkIdentDefs.newTree( + ident("hasField"), ident("IntSet"), newEmptyNode())) + + generateOneofFields(impl, typeSection) + + result = newStmtList() + add(result, typeSection) + + when defined(debug): + hint(repr(result)) + +proc generateNewMessageProc(desc: NimNode): NimNode = + let + body = newStmtList( + newCall(ident("new"), ident("result")) + ) + resultId = ident("result") + + for field in fields(desc): + let oneofName = oneofName(desc, field) + add(body, fieldInitializer(resultId, field, oneofName)) + + add(body, newAssignment(newDotExpr(resultId, ident("hasField")), + newCall(ident("initIntSet")))) + + result = newProc(postfix(ident("new" & getMessageName(desc)), "*"), + @[ident(getMessageName(desc))], + body) + +proc fieldProcName(prefix: string, field: NimNode): string = + result = prefix & capitalizeAscii(getFieldName(field)) + +proc fieldProcIdent(prefix: string, field: NimNode): NimNode = + result = postfix(ident(fieldProcName(prefix, field)), "*") + +proc generateClearFieldProc(desc, field: NimNode): NimNode = + let + messageId = ident("message") + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) + defvalue = defaultValue(field) + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + procName = fieldProcIdent("clear", field) + mtype = ident(getMessageName(desc)) + + result = quote do: + proc `procName`(`messageId`: `mtype`) = + `fname` = `defvalue` + excl(`hasfield`, `number`) + + # When clearing a field that is contained in a oneof, we should also clear + # the other fields. + for sibling in oneofFields(desc, oneofIndex(field)): + if sibling == field: + continue + let + number = getFieldNumber(sibling) + exclNode = quote do: + excl(`hasField`, `number`) + add(body(result), exclNode) + +proc generateHasFieldProc(desc, field: NimNode): NimNode = + let + messageId = ident("message") + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + mtype = ident(getMessageName(desc)) + procName = fieldProcIdent("has", field) + + result = quote do: + proc `procName`(`messageId`: `mtype`): bool = + contains(`hasfield`, `number`) + +proc generateSetFieldProc(desc, field: NimNode): NimNode = + let + messageId = ident("message") + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + valueId = ident("value") + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) + procName = fieldProcIdent("set", field) + mtype = ident(getMessageName(desc)) + ftype = getFullFieldType(field) + + result = quote do: + proc `procName`(`messageId`: `mtype`, `valueId`: `ftype`) = + `fname` = `valueId` + incl(`hasfield`, `number`) + + # When setting a field that is in a oneof, we need to unset the other fields + for sibling in oneofFields(desc, oneofIndex(field)): + if sibling == field: + continue + let + number = getFieldNumber(sibling) + exclNode = quote do: + excl(`hasField`, `number`) + add(body(result), exclNode) + +proc generateAddToFieldProc(desc, field: NimNode): NimNode = + let + procName = fieldProcIdent("add", field) + messageId = ident("message") + mtype = ident(getMessageName(desc)) + valueId = ident("value") + ftype = ident(getFieldTypeAsString(field)) + hasField = newDotExpr(messageId, ident("hasField")) + number = getFieldNumber(field) + fname = newDotExpr(messageId, ident(getFieldName(field))) + + result = quote do: + proc `procName`(`messageId`: `mtype`, `valueId`: `ftype`) = + add(`fname`, `valueId`) + incl(`hasfield`, `number`) + +proc ident(wt: WireType): NimNode = + result = newDotExpr(ident("WireType"), ident($wt)) + +proc genWriteField(message, field: NimNode): NimNode = + result = newStmtList() + + let + number = getFieldNumber(field) + writer = ident("write" & getFieldTypeAsString(field)) + messageId = ident("message") + fname = getFieldNameAST(messageId, field, oneofName(message, field)) + wiretype = ident(wiretype(field)) + sizeproc = ident("sizeOf" & getFieldTypeAsString(field)) + hasproc = ident(fieldProcName("has", field)) + + if not isRepeated(field): + result.add quote do: + if `hasproc`(message): + writeTag(stream, `number`, `wiretype`) + `writer`(stream, `fname`) + if isMessage(field): + insert(result[0][0][1], 1, quote do: + writeVarint(stream, `sizeproc`(`fname`)) + ) + else: + let valueId = ident("value") + if isPacked(field): + result.add quote do: + writeTag(stream, `number`, WireType.LengthDelimited) + writeVarInt(stream, packedFieldSize(`fname`, `wiretype`)) + for `valueId` in `fname`: + `writer`(stream, `valueId`) + else: + result.add quote do: + for `valueId` in `fname`: + writeTag(stream, `number`, `wiretype`) + `writer`(stream, `valueId`) + if isMessage(field): + insert(result[^1][^1], 1, quote do: + writeVarint(stream, `sizeproc`(`valueId`)) + ) + +proc generateWriteMessageProc(desc: NimNode): NimNode = + let + messageId = ident("message") + mtype = ident(getMessageName(desc)) + procName = postfix(ident("write" & getMessageName(desc)), "*") + body = newStmtList() + stream = ident("stream") + sizeproc = postfix(ident("sizeOf" & getMessageName(desc)), "*") + + for field in fields(desc): + add(body, genWriteField(desc, field)) + + result = quote do: + proc `sizeproc`(`messageId`: `mtype`): uint64 + + proc `procName`(`stream`: ProtobufStream, `messageId`: `mtype`) = + `body` + +proc generateReadMessageProc(desc: NimNode): NimNode = + let + procName = postfix(ident("read" & getMessageName(desc)), "*") + newproc = ident("new" & getMessageName(desc)) + streamId = ident("stream") + mtype = ident(getMessageName(desc)) + tagId = ident("tag") + wiretypeId = ident("wiretype") + resultId = ident("result") + + result = quote do: + proc `procName`(`streamId`: ProtobufStream): `mtype` = + `resultId` = `newproc`() + while not atEnd(stream): + let + `tagId` = readTag(`streamId`) + `wiretypeId` = getTagWireType(`tagId`) + case getTagFieldNumber(`tagId`) + else: + skipField(`streamId`, `wiretypeId`) + + let caseNode = body(result)[1][1][1] + + # TODO: check wiretypes and fail if it doesn't match + for field in fields(desc): + let + number = getFieldNumber(field) + reader = ident("read" & getFieldTypeAsString(field)) + setproc = + if isRepeated(field): + ident("add" & capitalizeAscii(getFieldName(field))) + else: + ident("set" & capitalizeAscii(getFieldName(field))) + if isRepeated(field): + if isNumeric(getFieldType(field)): + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + if `wiretypeId` == WireType.LengthDelimited: + let + size = readVarint(stream) + start = getPosition(stream).uint64 + var consumed = 0'u64 + while consumed < size: + `setproc`(`resultId`, `reader`(stream)) + consumed = getPosition(stream).uint64 - start + if consumed != size: + raise newException(Exception, "packed field size mismatch") + else: + `setproc`(`resultId`, `reader`(stream)) + )) + elif isMessage(field): + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `setproc`(`resultId`, `reader`(stream2)) + )) + else: + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + `setproc`(`resultId`, `reader`(stream)) + )) + else: + if isMessage(field): + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + let size = readVarint(stream) + let data = readStr(stream, int(size)) + let stream2 = newProtobufStream(newStringStream(data)) + `setproc`(`resultId`, `reader`(stream2)) + )) + else: + insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: + `setproc`(`resultId`, `reader`(stream)) + )) + +proc generateSizeOfMessageProc(desc: NimNode): NimNode = + let + name = getMessageName(desc) + body = newStmtList() + messageId = ident("message") + resultId = ident("result") + procName = postfix(ident("sizeOf" & getMessageName(desc)), "*") + mtype = ident(getMessageName(desc)) + + result = quote do: + proc `procName`(`messageId`: `mtype`): uint64 = + `resultId` = 0 + + let procBody = body(result) + + for field in fields(desc): + let + hasproc = ident(fieldProcName("has", field)) + sizeofproc = ident("sizeOf" & getFieldTypeAsString(field)) + fname = getFieldNameAST(messageId, field, oneofName(desc, field)) + number = getFieldNumber(field) + wiretype = ident(wiretype(field)) + + # TODO: packed + if isRepeated(field): + if isPacked(field): + procBody.add quote do: + if `hasproc`(`messageId`): + let + tagSize = sizeOfUint32(uint32(makeTag(`number`, WireType.LengthDelimited))) + dataSize = packedFieldSize(`fname`, `wiretype`) + sizeOfSize = sizeOfUint64(dataSize) + `resultId` = tagSize + dataSize + sizeOfSize + else: + procBody.add quote do: + for value in `fname`: + let + sizeOfField = `sizeofproc`(value) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + + sizeOfField + + sizeOfUint64(sizeOfField) + + tagSize + else: + let sizeOfFieldId = ident("sizeOfField") + + procBody.add quote do: + if `hasproc`(`messageId`): + let + `sizeOfFieldId` = `sizeofproc`(`fname`) + tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) + `resultId` = `resultId` + sizeOfField + tagSize + + if isMessage(field): + # For messages we need to include the size of the encoded size + let asgn = procBody[^1][0][1][1] + asgn[1] = infix(asgn[1], "+", newCall(ident("sizeOfUint64"), + sizeOfFieldId)) + +proc generateSerializeProc(desc: NimNode): NimNode = + let + mtype = ident(getMessageName(desc)) + procName = postfix(ident("serialize"), "*") + writer = ident("write" & getMessageName(desc)) + resultId = ident("result") + + result = quote do: + proc `procName`(message: `mtype`): string = + let + ss = newStringStream() + pbs = newProtobufStream(ss) + `writer`(pbs, message) + `resultId` = ss.data + +proc generateDeserializeProc(desc: NimNode): NimNode = + let + mtype = ident(getMessageName(desc)) + procName = postfix(ident("new" & getMessageName(desc)), "*") + reader = ident("read" & getMessageName(desc)) + resultId = ident("result") + + result = quote do: + proc `procName`(data: string): `mtype` = + let + ss = newStringStream(data) + pbs = newProtobufStream(ss) + `resultId` = `reader`(pbs) + +macro generateMessageProcs*(x: typed): typed = + let + desc = getImpl(symbol(x)) + + result = newStmtList( + generateNewMessageProc(desc), + ) + + for field in fields(desc): + add(result, generateClearFieldProc(desc, field)) + add(result, generateHasFieldProc(desc, field)) + add(result, generateSetFieldProc(desc, field)) + + if isRepeated(field): + add(result, generateAddToFieldProc(desc, field)) + + add(result, generateWriteMessageProc(desc)) + add(result, generateReadMessageProc(desc)) + add(result, generateSizeOfMessageProc(desc)) + add(result, generateSerializeProc(desc)) + add(result, generateDeserializeProc(desc)) + + when defined(debug): + hint(repr(result)) + +macro generateEnumType*(x: typed): typed = + let + impl = getImpl(symbol(x)) + name = $findColonExpr(impl, "name")[1] + values = findColonExpr(impl, "values")[1] + + let enumTy = nnkEnumTy.newTree(newEmptyNode()) + + for valueNode in values: + let + name = $findColonExpr(valueNode, "name")[1] + number = findColonExpr(valueNode, "number")[1] + + add(enumTy, nnkEnumFieldDef.newTree(ident(name), number)) + + result = newStmtList(nnkTypeSection.newTree( + nnkTypeDef.newTree( + nnkPragmaExpr.newTree( + postfix(ident(name), "*"), + nnkPragma.newTree(ident("pure")) + ), + newEmptyNode(), + enumTy + ) + )) + + when defined(debug): + hint(repr(result)) + +macro generateEnumProcs*(x: typed): typed = + let + impl = getImpl(symbol(x)) + name = $findColonExpr(impl, "name")[1] + nameId = ident(name) + values = findColonExpr(impl, "values")[1] + readProc = postfix(ident("read" & name), "*") + writeProc = postfix(ident("write" & name), "*") + sizeProc = postfix(ident("sizeOf" & name), "*") + resultId = ident("result") + + result = newStmtList() + + add(result, quote do: + proc `readProc`(stream: ProtobufStream): `nameId` = + `resultId` = `nameId`(readUInt32(stream)) + + proc `writeProc`(stream: ProtobufStream, value: `nameId`) = + writeEnum(stream, value) + + proc `sizeProc`(value: `nameId`): uint64 = + `resultId` = sizeOfUInt32(uint32(value)) + ) + + when defined(debug): + hint(repr(result)) diff --git a/generator/plugin_pb.nim b/generator/plugin_pb.nim index 2f2eb1b..7ddd381 100644 --- a/generator/plugin_pb.nim +++ b/generator/plugin_pb.nim @@ -1,6 +1,6 @@ import intsets -import protobuf/gen +import gen import protobuf/protobuf import descriptor_pb diff --git a/generator/protoc_gen_nim.nim b/generator/protoc_gen_nim.nim index 3680de7..71b22f2 100644 --- a/generator/protoc_gen_nim.nim +++ b/generator/protoc_gen_nim.nim @@ -11,7 +11,8 @@ import descriptor_pb import plugin_pb import protobuf/protobuf -import protobuf/gen + +import gen type Names = distinct seq[string] diff --git a/src/protobuf/gen.nim b/src/protobuf/gen.nim deleted file mode 100644 index 039b6f7..0000000 --- a/src/protobuf/gen.nim +++ /dev/null @@ -1,655 +0,0 @@ -import macros -import strutils - -import protobuf - -type - MessageDesc* = object - name*: string - fields*: seq[FieldDesc] - oneofs*: seq[string] - - FieldLabel* {.pure.} = enum - Optional = 1 - Required - Repeated - - FieldDesc* = object - name*: string - number*: int - ftype*: FieldType - label*: FieldLabel - typeName*: string - packed*: bool - oneofIdx*: int - - EnumDesc* = object - name*: string - values*: seq[EnumValueDesc] - - EnumValueDesc* = object - name*: string - number*: int - -proc findColonExpr(parent: NimNode, s: string): NimNode = - for child in parent: - if child.kind != nnkExprColonExpr: - continue - - if $child[0] == s: - return child - -proc getMessageName(desc: NimNode): string = - let node = findColonExpr(desc, "name") - result = $node[1] - -iterator fields(desc: NimNode): NimNode = - let node = findColonExpr(desc, "fields") - for field in node[1]: - yield field - -proc isRepeated(field: NimNode): bool = - let node = findColonExpr(field, "label") - let value = FieldLabel(node[1].intVal) - result = value == FieldLabel.Repeated - -proc isPacked(field: NimNode): bool = - let node = findColonExpr(field, "packed") - result = bool(node[1].intVal) - -proc getFieldType(field: NimNode): FieldType = - let node = findColonExpr(field, "ftype") - result = FieldType(node[1].intVal) - -proc isMessage(field: NimNode): bool = - result = getFieldType(field) == FieldType.Message - -proc isEnum(field: NimNode): bool = - result = getFieldType(field) == FieldType.Enum - -proc getFieldTypeName(field: NimNode): string = - let node = findColonExpr(field, "typeName") - result = $node[1] - -proc getFieldTypeAsString(field: NimNode): string = - if isMessage(field) or isEnum(field): - result = getFieldTypeName(field) - else: - case getFieldType(field) - of FieldType.Double: result = "float64" - of FieldType.Float: result = "float32" - of FieldType.Int64: result = "int64" - of FieldType.UInt64: result = "uint64" - of FieldType.Int32: result = "int32" - of FieldType.Fixed64: result = "uint64" - of FieldType.Fixed32: result = "uint32" - of FieldType.Bool: result = "bool" - of FieldType.String: result = "string" - of FieldType.Bytes: result = "bytes" - of FieldType.UInt32: result = "uint32" - of FieldType.SFixed32: result = "int32" - of FieldType.SFixed64: result = "int64" - of FieldType.SInt32: result = "int32" - of FieldType.SInt64: result = "int64" - else: result = "AYBABTU" - -proc getFullFieldType(field: NimNode): NimNode = - result = ident(getFieldTypeAsString(field)) - if isRepeated(field): - result = nnkBracketExpr.newTree(ident("seq"), result) - -proc getFieldName(field: NimNode): string = - let node = findColonExpr(field, "name") - result = $node[1] - -proc getFieldNumber(field: NimNode): int = - result = int(findColonExpr(field, "number")[1].intVal) - -proc defaultValue(field: NimNode): NimNode = - # TODO: check if there is a default value specified for the field - - if isRepeated(field): - return nnkPrefix.newTree(newIdentNode("@"), nnkBracket.newTree()) - - case getFieldType(field) - of FieldType.Double: result = newLit(0.0'f64) - of FieldType.Float: result = newLit(0.0'f32) - of FieldType.Int64: result = newLit(0'i64) - of FieldType.UInt64: result = newLit(0'u64) - of FieldType.Int32: result = newLit(0'i32) - of FieldType.Fixed64: result = newLit(0'u64) - of FieldType.Fixed32: result = newLit(0'u32) - of FieldType.Bool: result = newLit(false) - of FieldType.String: result = newLit("") - of FieldType.Group: result = newLit("NOTIMPLEMENTED") - of FieldType.Message: result = newCall(ident("new" & getFieldTypeAsString(field))) - of FieldType.Bytes: result = newCall(ident("bytes"), newLit("")) - of FieldType.UInt32: result = newLit(0'u32) - of FieldType.Enum: - let - descId = ident(getFieldTypeAsString(field) & "Desc") - nameId = ident(getFieldTypeAsString(field)) - result = quote do: - `nameId`(`descId`.values[0].number) - of FieldType.SFixed32: result = newLit(0'u32) - of FieldType.SFixed64: result = newLit(0'u32) - of FieldType.SInt32: result = newLit(0) - of FieldType.SInt64: result = newLit(0) - -proc wiretype(field: NimNode): WireType = - result = wiretype(getFieldType(field)) - -# TODO: maybe not the best name for this -proc getFieldNameAST(objname: NimNode, field: NimNode, oneof: string): NimNode = - result = - if oneof != "": - newDotExpr(newDotExpr(objname, ident(oneof)), ident(getFieldName(field))) - else: - newDotExpr(objname, ident(getFieldName(field))) - -proc fieldInitializer(objname: NimNode, field: NimNode, oneof: string): NimNode = - result = nnkAsgn.newTree( - getFieldNameAST(objname, field, oneof), - defaultValue(field) - ) - -proc oneofIndex(field: NimNode): int = - let node = findColonExpr(field, "oneofIdx") - if node == nil: - result = -1 - else: - result = int(node[1].intVal) - -proc oneofName(message, field: NimNode): string = - let index = oneofIndex(field) - - if index == -1: - return "" - - let oneofs = findColonExpr(message, "oneofs")[1] - - result = $oneofs[index] - -iterator oneofFields(message: NimNode, index: int): NimNode = - if index != -1: - for field in fields(message): - if oneofIndex(field) == index: - yield field - -proc generateOneofFields*(desc: NimNode, typeSection: NimNode) = - let - oneofs = findColonExpr(desc, "oneofs")[1] - messageName = getMessageName(desc) - - for index, oneof in oneofs: - let reclist = nnkRecList.newTree() - - for field in oneofFields(desc, index): - let ftype = getFullFieldType(field) - let name = ident(getFieldName(field)) - - add(reclist, newIdentDefs(postfix(name, "*"), ftype)) - - let typedef = nnkTypeDef.newTree( - nnkPragmaExpr.newTree( - postfix(ident(messageName & $oneof), "*"), - nnkPragma.newTree( - ident("union") - ) - ), - newEmptyNode(), - nnkObjectTy.newTree( - newEmptyNode(), - newEmptyNode(), - reclist - ) - ) - - add(typeSection, typedef) - -macro generateMessageType*(desc: typed): typed = - let - impl = getImpl(symbol(desc)) - typeSection = nnkTypeSection.newTree() - typedef = nnkTypeDef.newTree() - reclist = nnkRecList.newTree() - oneofs = findColonExpr(impl, "oneofs")[1] - - let name = getMessageName(impl) - - let typedefRef = nnkTypeDef.newTree(postfix(newIdentNode(name), "*"), newEmptyNode(), - nnkRefTy.newTree(newIdentNode(name & "Obj"))) - add(typeSection, typedefRef) - - add(typeSection, typedef) - - add(typedef, postfix(ident(name & "Obj"), "*")) - add(typedef, newEmptyNode()) - add(typedef, nnkObjectTy.newTree(newEmptyNode(), newEmptyNode(), reclist)) - - for field in fields(impl): - let ftype = getFullFieldType(field) - let name = ident(getFieldName(field)) - if oneofIndex(field) == -1: - add(reclist, newIdentDefs(postfix(name, "*"), ftype)) - - for oneof in oneofs: - add(reclist, newIdentDefs(postfix(ident($oneof), "*"), - ident(name & $oneof))) - - add(reclist, nnkIdentDefs.newTree( - ident("hasField"), ident("IntSet"), newEmptyNode())) - - generateOneofFields(impl, typeSection) - - result = newStmtList() - add(result, typeSection) - - when defined(debug): - hint(repr(result)) - -proc generateNewMessageProc(desc: NimNode): NimNode = - let - body = newStmtList( - newCall(ident("new"), ident("result")) - ) - resultId = ident("result") - - for field in fields(desc): - let oneofName = oneofName(desc, field) - add(body, fieldInitializer(resultId, field, oneofName)) - - add(body, newAssignment(newDotExpr(resultId, ident("hasField")), - newCall(ident("initIntSet")))) - - result = newProc(postfix(ident("new" & getMessageName(desc)), "*"), - @[ident(getMessageName(desc))], - body) - -proc fieldProcName(prefix: string, field: NimNode): string = - result = prefix & capitalizeAscii(getFieldName(field)) - -proc fieldProcIdent(prefix: string, field: NimNode): NimNode = - result = postfix(ident(fieldProcName(prefix, field)), "*") - -proc generateClearFieldProc(desc, field: NimNode): NimNode = - let - messageId = ident("message") - fname = getFieldNameAST(messageId, field, oneofName(desc, field)) - defvalue = defaultValue(field) - hasField = newDotExpr(messageId, ident("hasField")) - number = getFieldNumber(field) - procName = fieldProcIdent("clear", field) - mtype = ident(getMessageName(desc)) - - result = quote do: - proc `procName`(`messageId`: `mtype`) = - `fname` = `defvalue` - excl(`hasfield`, `number`) - - # When clearing a field that is contained in a oneof, we should also clear - # the other fields. - for sibling in oneofFields(desc, oneofIndex(field)): - if sibling == field: - continue - let - number = getFieldNumber(sibling) - exclNode = quote do: - excl(`hasField`, `number`) - add(body(result), exclNode) - -proc generateHasFieldProc(desc, field: NimNode): NimNode = - let - messageId = ident("message") - hasField = newDotExpr(messageId, ident("hasField")) - number = getFieldNumber(field) - mtype = ident(getMessageName(desc)) - procName = fieldProcIdent("has", field) - - result = quote do: - proc `procName`(`messageId`: `mtype`): bool = - contains(`hasfield`, `number`) - -proc generateSetFieldProc(desc, field: NimNode): NimNode = - let - messageId = ident("message") - hasField = newDotExpr(messageId, ident("hasField")) - number = getFieldNumber(field) - valueId = ident("value") - fname = getFieldNameAST(messageId, field, oneofName(desc, field)) - procName = fieldProcIdent("set", field) - mtype = ident(getMessageName(desc)) - ftype = getFullFieldType(field) - - result = quote do: - proc `procName`(`messageId`: `mtype`, `valueId`: `ftype`) = - `fname` = `valueId` - incl(`hasfield`, `number`) - - # When setting a field that is in a oneof, we need to unset the other fields - for sibling in oneofFields(desc, oneofIndex(field)): - if sibling == field: - continue - let - number = getFieldNumber(sibling) - exclNode = quote do: - excl(`hasField`, `number`) - add(body(result), exclNode) - -proc generateAddToFieldProc(desc, field: NimNode): NimNode = - let - procName = fieldProcIdent("add", field) - messageId = ident("message") - mtype = ident(getMessageName(desc)) - valueId = ident("value") - ftype = ident(getFieldTypeAsString(field)) - hasField = newDotExpr(messageId, ident("hasField")) - number = getFieldNumber(field) - fname = newDotExpr(messageId, ident(getFieldName(field))) - - result = quote do: - proc `procName`(`messageId`: `mtype`, `valueId`: `ftype`) = - add(`fname`, `valueId`) - incl(`hasfield`, `number`) - -proc ident(wt: WireType): NimNode = - result = newDotExpr(ident("WireType"), ident($wt)) - -proc genWriteField(message, field: NimNode): NimNode = - result = newStmtList() - - let - number = getFieldNumber(field) - writer = ident("write" & getFieldTypeAsString(field)) - messageId = ident("message") - fname = getFieldNameAST(messageId, field, oneofName(message, field)) - wiretype = ident(wiretype(field)) - sizeproc = ident("sizeOf" & getFieldTypeAsString(field)) - hasproc = ident(fieldProcName("has", field)) - - if not isRepeated(field): - result.add quote do: - if `hasproc`(message): - writeTag(stream, `number`, `wiretype`) - `writer`(stream, `fname`) - if isMessage(field): - insert(result[0][0][1], 1, quote do: - writeVarint(stream, `sizeproc`(`fname`)) - ) - else: - let valueId = ident("value") - if isPacked(field): - result.add quote do: - writeTag(stream, `number`, WireType.LengthDelimited) - writeVarInt(stream, packedFieldSize(`fname`, `wiretype`)) - for `valueId` in `fname`: - `writer`(stream, `valueId`) - else: - result.add quote do: - for `valueId` in `fname`: - writeTag(stream, `number`, `wiretype`) - `writer`(stream, `valueId`) - if isMessage(field): - insert(result[^1][^1], 1, quote do: - writeVarint(stream, `sizeproc`(`valueId`)) - ) - -proc generateWriteMessageProc(desc: NimNode): NimNode = - let - messageId = ident("message") - mtype = ident(getMessageName(desc)) - procName = postfix(ident("write" & getMessageName(desc)), "*") - body = newStmtList() - stream = ident("stream") - sizeproc = postfix(ident("sizeOf" & getMessageName(desc)), "*") - - for field in fields(desc): - add(body, genWriteField(desc, field)) - - result = quote do: - proc `sizeproc`(`messageId`: `mtype`): uint64 - - proc `procName`(`stream`: ProtobufStream, `messageId`: `mtype`) = - `body` - -proc generateReadMessageProc(desc: NimNode): NimNode = - let - procName = postfix(ident("read" & getMessageName(desc)), "*") - newproc = ident("new" & getMessageName(desc)) - streamId = ident("stream") - mtype = ident(getMessageName(desc)) - tagId = ident("tag") - wiretypeId = ident("wiretype") - resultId = ident("result") - - result = quote do: - proc `procName`(`streamId`: ProtobufStream): `mtype` = - `resultId` = `newproc`() - while not atEnd(stream): - let - `tagId` = readTag(`streamId`) - `wiretypeId` = getTagWireType(`tagId`) - case getTagFieldNumber(`tagId`) - else: - skipField(`streamId`, `wiretypeId`) - - let caseNode = body(result)[1][1][1] - - # TODO: check wiretypes and fail if it doesn't match - for field in fields(desc): - let - number = getFieldNumber(field) - reader = ident("read" & getFieldTypeAsString(field)) - setproc = - if isRepeated(field): - ident("add" & capitalizeAscii(getFieldName(field))) - else: - ident("set" & capitalizeAscii(getFieldName(field))) - if isRepeated(field): - if isNumeric(getFieldType(field)): - insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: - if `wiretypeId` == WireType.LengthDelimited: - let - size = readVarint(stream) - start = getPosition(stream).uint64 - var consumed = 0'u64 - while consumed < size: - `setproc`(`resultId`, `reader`(stream)) - consumed = getPosition(stream).uint64 - start - if consumed != size: - raise newException(Exception, "packed field size mismatch") - else: - `setproc`(`resultId`, `reader`(stream)) - )) - elif isMessage(field): - insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: - let size = readVarint(stream) - let data = readStr(stream, int(size)) - let stream2 = newProtobufStream(newStringStream(data)) - `setproc`(`resultId`, `reader`(stream2)) - )) - else: - insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: - `setproc`(`resultId`, `reader`(stream)) - )) - else: - if isMessage(field): - insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: - let size = readVarint(stream) - let data = readStr(stream, int(size)) - let stream2 = newProtobufStream(newStringStream(data)) - `setproc`(`resultId`, `reader`(stream2)) - )) - else: - insert(caseNode, 1, nnkOfBranch.newTree(newLit(number), quote do: - `setproc`(`resultId`, `reader`(stream)) - )) - -proc generateSizeOfMessageProc(desc: NimNode): NimNode = - let - name = getMessageName(desc) - body = newStmtList() - messageId = ident("message") - resultId = ident("result") - procName = postfix(ident("sizeOf" & getMessageName(desc)), "*") - mtype = ident(getMessageName(desc)) - - result = quote do: - proc `procName`(`messageId`: `mtype`): uint64 = - `resultId` = 0 - - let procBody = body(result) - - for field in fields(desc): - let - hasproc = ident(fieldProcName("has", field)) - sizeofproc = ident("sizeOf" & getFieldTypeAsString(field)) - fname = getFieldNameAST(messageId, field, oneofName(desc, field)) - number = getFieldNumber(field) - wiretype = ident(wiretype(field)) - - # TODO: packed - if isRepeated(field): - if isPacked(field): - procBody.add quote do: - if `hasproc`(`messageId`): - let - tagSize = sizeOfUint32(uint32(makeTag(`number`, WireType.LengthDelimited))) - dataSize = packedFieldSize(`fname`, `wiretype`) - sizeOfSize = sizeOfUint64(dataSize) - `resultId` = tagSize + dataSize + sizeOfSize - else: - procBody.add quote do: - for value in `fname`: - let - sizeOfField = `sizeofproc`(value) - tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) - `resultId` = `resultId` + - sizeOfField + - sizeOfUint64(sizeOfField) + - tagSize - else: - let sizeOfFieldId = ident("sizeOfField") - - procBody.add quote do: - if `hasproc`(`messageId`): - let - `sizeOfFieldId` = `sizeofproc`(`fname`) - tagSize = sizeOfUint32(uint32(makeTag(`number`, `wiretype`))) - `resultId` = `resultId` + sizeOfField + tagSize - - if isMessage(field): - # For messages we need to include the size of the encoded size - let asgn = procBody[^1][0][1][1] - asgn[1] = infix(asgn[1], "+", newCall(ident("sizeOfUint64"), - sizeOfFieldId)) - -proc generateSerializeProc(desc: NimNode): NimNode = - let - mtype = ident(getMessageName(desc)) - procName = postfix(ident("serialize"), "*") - writer = ident("write" & getMessageName(desc)) - resultId = ident("result") - - result = quote do: - proc `procName`(message: `mtype`): string = - let - ss = newStringStream() - pbs = newProtobufStream(ss) - `writer`(pbs, message) - `resultId` = ss.data - -proc generateDeserializeProc(desc: NimNode): NimNode = - let - mtype = ident(getMessageName(desc)) - procName = postfix(ident("new" & getMessageName(desc)), "*") - reader = ident("read" & getMessageName(desc)) - resultId = ident("result") - - result = quote do: - proc `procName`(data: string): `mtype` = - let - ss = newStringStream(data) - pbs = newProtobufStream(ss) - `resultId` = `reader`(pbs) - -macro generateMessageProcs*(x: typed): typed = - let - desc = getImpl(symbol(x)) - - result = newStmtList( - generateNewMessageProc(desc), - ) - - for field in fields(desc): - add(result, generateClearFieldProc(desc, field)) - add(result, generateHasFieldProc(desc, field)) - add(result, generateSetFieldProc(desc, field)) - - if isRepeated(field): - add(result, generateAddToFieldProc(desc, field)) - - add(result, generateWriteMessageProc(desc)) - add(result, generateReadMessageProc(desc)) - add(result, generateSizeOfMessageProc(desc)) - add(result, generateSerializeProc(desc)) - add(result, generateDeserializeProc(desc)) - - when defined(debug): - hint(repr(result)) - -macro generateEnumType*(x: typed): typed = - let - impl = getImpl(symbol(x)) - name = $findColonExpr(impl, "name")[1] - values = findColonExpr(impl, "values")[1] - - let enumTy = nnkEnumTy.newTree(newEmptyNode()) - - for valueNode in values: - let - name = $findColonExpr(valueNode, "name")[1] - number = findColonExpr(valueNode, "number")[1] - - add(enumTy, nnkEnumFieldDef.newTree(ident(name), number)) - - result = newStmtList(nnkTypeSection.newTree( - nnkTypeDef.newTree( - nnkPragmaExpr.newTree( - postfix(ident(name), "*"), - nnkPragma.newTree(ident("pure")) - ), - newEmptyNode(), - enumTy - ) - )) - - when defined(debug): - hint(repr(result)) - -macro generateEnumProcs*(x: typed): typed = - let - impl = getImpl(symbol(x)) - name = $findColonExpr(impl, "name")[1] - nameId = ident(name) - values = findColonExpr(impl, "values")[1] - readProc = postfix(ident("read" & name), "*") - writeProc = postfix(ident("write" & name), "*") - sizeProc = postfix(ident("sizeOf" & name), "*") - resultId = ident("result") - - result = newStmtList() - - add(result, quote do: - proc `readProc`(stream: ProtobufStream): `nameId` = - `resultId` = `nameId`(readUInt32(stream)) - - proc `writeProc`(stream: ProtobufStream, value: `nameId`) = - writeEnum(stream, value) - - proc `sizeProc`(value: `nameId`): uint64 = - `resultId` = sizeOfUInt32(uint32(value)) - ) - - when defined(debug): - hint(repr(result)) -- cgit v1.2.3