aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--generator/protoc_gen_nim.nim161
-rw-r--r--src/protobuf/protobuf.nim94
2 files changed, 179 insertions, 76 deletions
diff --git a/generator/protoc_gen_nim.nim b/generator/protoc_gen_nim.nim
index 1304ed9..3680de7 100644
--- a/generator/protoc_gen_nim.nim
+++ b/generator/protoc_gen_nim.nim
@@ -510,6 +510,9 @@ iterator genProcs(e: Enum): string =
yield &"proc write{e.names}*(stream: ProtobufStream, value: {e.names}) ="
yield indent(&"writeUInt32(stream, uint32(value))", 4)
yield ""
+ yield &"proc write{e.names}*(stream: ProtobufStream, value: {e.names}, fieldNumber: int) ="
+ yield indent(&"writeUInt32(stream, uint32(value), fieldNumber)", 4)
+ yield ""
yield &"proc sizeOf{e.names}*(value: {e.names}): uint64 ="
yield indent(&"sizeOfUInt32(uint32(value))", 4)
@@ -531,23 +534,33 @@ iterator oneofSiblings(field: Field): Field =
iterator genClearFieldProc(msg: Message, field: Field): string =
yield &"proc clear{field.name}*(message: {msg.names}) ="
yield indent(&"message.{field.accessor} = {defaultValue(field)}", 4)
- yield indent(&"excl(message.hasField, {field.number})", 4)
+ var numbers: seq[int] = @[field.number]
for sibling in oneofSiblings(field):
- yield indent(&"excl(message.hasField, {sibling.number})", 4)
+ add(numbers, sibling.number)
+ yield indent(&"excl(message.hasField, [{join(numbers, \", \")}])", 4)
yield ""
iterator genHasFieldProc(msg: Message, field: Field): string =
# TODO: if map/seq, check also if there are values!
yield &"proc has{field.name}*(message: {msg.names}): bool ="
- yield indent(&"result = contains(message.hasField, {field.number})", 4)
+ var check = indent(&"result = contains(message.hasField, {field.number})", 4)
+ if isRepeated(field) or isMapEntry(field):
+ check = &"{check} or (len(message.{field.accessor}) > 0)"
+ yield check
+ # elif isMapEntry(field):
+ # base = &"{base} or (len(message.{}"
+ # yield indent(&"result = contains(message.hasField, {field.number})", 4)
yield ""
iterator genSetFieldProc(msg: Message, field: Field): string =
yield &"proc set{field.name}*(message: {msg.names}, value: {field.fullType}) ="
yield indent(&"message.{field.accessor} = value", 4)
yield indent(&"incl(message.hasField, {field.number})", 4)
+ var numbers: seq[int] = @[]
for sibling in oneofSiblings(field):
- yield indent(&"excl(message.hasField, {sibling.number})", 4)
+ add(numbers, sibling.number)
+ if len(numbers) > 0:
+ yield indent(&"excl(message.hasField, [{join(numbers, \", \")}])", 4)
yield ""
iterator genAddToFieldProc(msg: Message, field: Field): string =
@@ -572,13 +585,12 @@ iterator genWriteMapKVProc(msg: Message): string =
yield &"proc write{msg.names}KV(stream: ProtobufStream, key: {key.fullType}, value: {value.fullType}) ="
- yield indent(&"writeTag(stream, {key.number}, {wiretypeStr(key)})", 4)
- yield indent(&"write{key.typeName}(stream, key)", 4)
+ yield indent(&"write{key.typeName}(stream, key, {key.number})", 4)
- yield indent(&"writeTag(stream, {value.number}, {wiretypeStr(value)})", 4)
if isMessage(value):
- yield indent(&"writeVarint(stream, sizeOf{value.typeName}(value))", 4)
- yield indent(&"write{value.typeName}(stream, value)", 4)
+ yield indent(&"writeMessage(stream, value, {value.number})", 4)
+ else:
+ yield indent(&"write{value.typeName}(stream, value, {value.number})", 4)
yield ""
@@ -600,16 +612,16 @@ iterator genWriteMessageProc(msg: Message): string =
yield indent(&"{writer}(stream, value)", 12)
else:
yield indent(&"for value in message.{field.name}:", 4)
- yield indent(&"writeTag(stream, {field.number}, {wiretypeStr(field)})", 8)
if isMessage(field):
- yield indent(&"writeVarint(stream, sizeOf{field.typeName}(value))", 8)
- yield indent(&"{writer}(stream, value)", 8)
+ yield indent(&"writeMessage(stream, value, {field.number})", 8)
+ else:
+ yield indent(&"{writer}(stream, value, {field.number})", 8)
else:
yield indent(&"if has{field.name}(message):", 4)
- yield indent(&"writeTag(stream, {field.number}, {wiretypeStr(field)})", 8)
if isMessage(field):
- yield indent(&"writeVarint(stream, sizeOf{field.typeName}(message.{field.accessor}))", 8)
- yield indent(&"{writer}(stream, message.{field.accessor})", 8)
+ yield indent(&"writeMessage(stream, message.{field.accessor}, {field.number})", 8)
+ else:
+ yield indent(&"{writer}(stream, message.{field.accessor}, {field.number})", 8)
if len(msg.fields) == 0:
yield indent("discard", 4)
@@ -725,13 +737,18 @@ iterator genSizeOfMapKVProc(message: Message): string =
value = mapValueField(message)
yield &"proc sizeOf{message.names}KV(key: {key.fullType}, value: {value.fullType}): uint64 ="
+
+ # Key (cannot be message or other complex field)
+ yield indent(&"result = result + sizeOfTag({key.number}, {key.wiretypeStr})", 4)
yield indent(&"result = result + sizeOf{key.typeName}(key)", 4)
- yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({key.number}, {key.wiretypeStr})))", 4)
- yield indent(&"let valueSize = sizeOf{value.typeName}(value)", 4)
- yield indent(&"result = result + valueSize", 4)
- yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({value.number}, {value.wiretypeStr})))", 4)
+
+ # Value
+ yield indent(&"result = result + sizeOfTag({value.number}, {value.wiretypeStr})", 4)
if isMessage(value):
- yield indent(&"result = result + sizeOfUInt64(valueSize)", 4)
+ yield indent(&"result = result + sizeOfLengthDelimited(sizeOf{value.typeName}(value))", 4)
+ else:
+ yield indent(&"result = result + sizeOf{value.typeName}(value)", 4)
+
yield ""
iterator genSizeOfMessageProc(msg: Message): string =
@@ -742,36 +759,27 @@ iterator genSizeOfMessageProc(msg: Message): string =
yield indent(&"var sizeOfKV = 0'u64", 8)
yield indent(&"for key, value in message.{field.name}:", 8)
yield indent(&"sizeOfKV = sizeOfKV + sizeOf{field.typeName}KV(key, value)", 12)
- yield indent(&"let sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)})))", 8)
- yield indent("result = result + sizeOfKV + sizeOfTag + sizeOfUInt64(sizeOfKV)", 8)
+ yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8)
+ yield indent(&"result = result + sizeOfLengthDelimited(sizeOfKV)", 8)
elif isRepeated(field):
if isNumeric(field):
- yield indent(&"""
-if has{field.name}(message):
- let
- sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, WireType.LengthDelimited)))
- sizeOfData = packedFieldSize(message.{field.name}, {field.fieldTypeStr})
- sizeOfSize = sizeOfUInt64(sizeOfData)
- result = sizeOfTag + sizeOfData + sizeOfSize""", 4)
+ yield indent(&"if has{field.name}(message):", 4)
+ yield indent(&"result = result + sizeOfTag({field.number}, WireType.LengthDelimited)", 8)
+ yield indent(&"result = result + sizeOfLengthDelimited(packedFieldSize(message.{field.name}, {field.fieldTypeStr}))", 8)
else:
- yield indent(&"""
-for value in message.{field.name}:
- let
- sizeOfValue = sizeOf{field.typeName}(value)
- sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)})))
- result = result + sizeOfValue + sizeOfTag
-""", 4)
+ yield indent(&"for value in message.{field.name}:", 4)
+ yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8)
if isMessage(field):
- yield indent("result = result + sizeOfUInt64(sizeOfValue)", 8)
+ yield indent(&"result = result + sizeOfLengthDelimited(sizeOf{field.typeName}(value))", 8)
+ else:
+ yield indent(&"result = result + sizeOf{field.typeName}(value)", 8)
else:
- yield indent(&"""
-if has{field.name}(message):
- let
- sizeOfField = sizeOf{field.typeName}(message.{field.accessor})
- sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)})))
- result = result + sizeOfField + sizeOfTag""", 4)
+ yield indent(&"if has{field.name}(message):", 4)
+ yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8)
if isMessage(field):
- yield indent("result = result + sizeOfUInt64(sizeOfField)", 8)
+ yield indent(&"result = result + sizeOfLengthDelimited(sizeOf{field.typeName}(message.{field.accessor}))", 8)
+ else:
+ yield indent(&"result = result + sizeOf{field.typeName}(message.{field.accessor})", 8)
if len(msg.fields) == 0:
yield indent("result = 0", 4)
@@ -779,10 +787,11 @@ if has{field.name}(message):
yield ""
iterator genMessageProcForwards(msg: Message): string =
- yield &"proc new{msg.names}*(): {msg.names}"
- yield &"proc write{msg.names}*(stream: ProtobufStream, message: {msg.names})"
- yield &"proc read{msg.names}*(stream: ProtobufStream): {msg.names}"
- yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64"
+ if not isMapEntry(msg):
+ yield &"proc new{msg.names}*(): {msg.names}"
+ yield &"proc write{msg.names}*(stream: ProtobufStream, message: {msg.names})"
+ yield &"proc read{msg.names}*(stream: ProtobufStream): {msg.names}"
+ yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64"
if isMapEntry(msg):
let
@@ -794,41 +803,41 @@ iterator genMessageProcForwards(msg: Message): string =
yield &"proc sizeOf{msg.names}KV(key: {key.fullType}, value: {value.fullType}): uint64"
iterator genProcs(msg: Message): string =
- for line in genNewMessageProc(msg): yield line
-
- for field in msg.fields:
- for line in genClearFieldProc(msg, field): yield line
- for line in genHasFieldProc(msg, field): yield line
- for line in genSetFieldProc(msg, field): yield line
-
- if isRepeated(field) and not isMapEntry(field):
- for line in genAddToFieldProc(msg, field): yield line
-
- for line in genFieldAccessorProcs(msg, field): yield line
-
if isMapEntry(msg):
for line in genSizeOfMapKVProc(msg): yield line
for line in genWriteMapKVProc(msg): yield line
for line in genReadMapKVProc(msg): yield line
+ else:
+ for line in genNewMessageProc(msg): yield line
- for line in genSizeOfMessageProc(msg): yield line
- for line in genWriteMessageProc(msg): yield line
- for line in genReadMessageProc(msg): yield line
+ for field in msg.fields:
+ for line in genClearFieldProc(msg, field): yield line
+ for line in genHasFieldProc(msg, field): yield line
+ for line in genSetFieldProc(msg, field): yield line
- yield &"proc serialize*(message: {msg.names}): string ="
- yield indent("let", 4)
- yield indent("ss = newStringStream()", 8)
- yield indent("pbs = newProtobufStream(ss)", 8)
- yield indent(&"write{msg.names}(pbs, message)", 4)
- yield indent("result = ss.data", 4)
- yield ""
+ if isRepeated(field) and not isMapEntry(field):
+ for line in genAddToFieldProc(msg, field): yield line
- yield &"proc new{msg.names}*(data: string): {msg.names} ="
- yield indent("let", 4)
- yield indent("ss = newStringStream(data)", 8)
- yield indent("pbs = newProtobufStream(ss)", 8)
- yield indent(&"result = read{msg.names}(pbs)", 4)
- yield ""
+ for line in genFieldAccessorProcs(msg, field): yield line
+
+ for line in genSizeOfMessageProc(msg): yield line
+ for line in genWriteMessageProc(msg): yield line
+ for line in genReadMessageProc(msg): yield line
+
+ yield &"proc serialize*(message: {msg.names}): string ="
+ yield indent("let", 4)
+ yield indent("ss = newStringStream()", 8)
+ yield indent("pbs = newProtobufStream(ss)", 8)
+ yield indent(&"write{msg.names}(pbs, message)", 4)
+ yield indent("result = ss.data", 4)
+ yield ""
+
+ yield &"proc new{msg.names}*(data: string): {msg.names} ="
+ yield indent("let", 4)
+ yield indent("ss = newStringStream(data)", 8)
+ yield indent("pbs = newProtobufStream(ss)", 8)
+ yield indent(&"result = read{msg.names}(pbs)", 4)
+ yield ""
proc processFile(filename: string, fdesc: FileDescriptorProto,
otherFiles: TableRef[string, ProtoFile]): ProcessedFile =
diff --git a/src/protobuf/protobuf.nim b/src/protobuf/protobuf.nim
index 98bb7a4..d15bfc7 100644
--- a/src/protobuf/protobuf.nim
+++ b/src/protobuf/protobuf.nim
@@ -1,4 +1,6 @@
import endians
+import intsets
+import macros
import streams
import strutils
@@ -190,42 +192,70 @@ proc readTag*(stream: ProtobufStream): Tag =
proc writeInt32*(stream: ProtobufStream, n: int32) =
writeVarint(stream, n.uint64)
+proc writeInt32*(stream: ProtobufStream, n: int32, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Varint)
+ writeInt32(stream, n)
+
proc readInt32*(stream: ProtobufStream): int32 =
result = readVarint(stream).int32
proc writeSInt32*(stream: ProtobufStream, n: int32) =
writeVarint(stream, zigzagEncode(n))
+proc writeSInt32*(stream: ProtobufStream, n: int32, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Varint)
+ writeVarint(stream, zigzagEncode(n))
+
proc readSInt32*(stream: ProtobufStream): int32 =
result = zigzagDecode(readVarint(stream).uint32)
proc writeUInt32*(stream: ProtobufStream, n: uint32) =
writeVarint(stream, n)
+proc writeUInt32*(stream: ProtobufStream, n: uint32, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Varint)
+ writeVarint(stream, n)
+
proc readUInt32*(stream: ProtobufStream): uint32 =
result = readVarint(stream).uint32
proc writeInt64*(stream: ProtobufStream, n: int64) =
writeVarint(stream, n.uint64)
+proc writeInt64*(stream: ProtobufStream, n: int64, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Varint)
+ writeVarint(stream, n.uint64)
+
proc readInt64*(stream: ProtobufStream): int64 =
result = readVarint(stream).int64
proc writeSInt64*(stream: ProtobufStream, n: int64) =
writeVarint(stream, zigzagEncode(n))
+proc writeSInt64*(stream: ProtobufStream, n: int64, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Varint)
+ writeVarint(stream, zigzagEncode(n))
+
proc readSInt64*(stream: ProtobufStream): int64 =
result = zigzagDecode(readVarint(stream))
proc writeUInt64*(stream: ProtobufStream, n: uint64) =
writeVarint(stream, n)
+proc writeUInt64*(stream: ProtobufStream, n: uint64, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Varint)
+ writeVarint(stream, n)
+
proc readUInt64*(stream: ProtobufStream): uint64 =
result = readVarint(stream)
proc writeBool*(stream: ProtobufStream, value: bool) =
writeVarint(stream, value.uint32)
+proc writeBool*(stream: ProtobufStream, n: bool, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Varint)
+ writeVarint(stream, n.uint32)
+
proc readBool*(stream: ProtobufStream): bool =
result = readVarint(stream).bool
@@ -238,6 +268,10 @@ proc writeFixed64*(stream: ProtobufStream, value: uint64) =
write(stream, output)
+proc writeFixed64*(stream: ProtobufStream, n: uint64, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Fixed64)
+ writeFixed64(stream, n)
+
proc readFixed64*(stream: ProtobufStream): uint64 =
var tmp: uint64
if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp):
@@ -247,6 +281,10 @@ proc readFixed64*(stream: ProtobufStream): uint64 =
proc writeSFixed64*(stream: ProtobufStream, value: int64) =
writeFixed64(stream, cast[uint64](value))
+proc writeSFixed64*(stream: ProtobufStream, value: int64, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Fixed64)
+ writeSFixed64(stream, value)
+
proc readSFixed64*(stream: ProtobufStream): int64 =
result = cast[int64](readFixed64(stream))
@@ -259,6 +297,10 @@ proc writeDouble*(stream: ProtobufStream, value: float64) =
write(stream, output)
+proc writeDouble*(stream: ProtobufStream, value: float64, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Fixed64)
+ writeDouble(stream, value)
+
proc readDouble*(stream: ProtobufStream): float64 =
var tmp: uint64
if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp):
@@ -274,6 +316,10 @@ proc writeFixed32*(stream: ProtobufStream, value: uint32) =
write(stream, output)
+proc writeFixed32*(stream: ProtobufStream, value: uint32, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Fixed32)
+ writeFixed32(stream, value)
+
proc readFixed32*(stream: ProtobufStream): uint32 =
var tmp: uint32
if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp):
@@ -283,6 +329,10 @@ proc readFixed32*(stream: ProtobufStream): uint32 =
proc writeSFixed32*(stream: ProtobufStream, value: int32) =
writeFixed32(stream, cast[uint32](value))
+proc writeSFixed32*(stream: ProtobufStream, value: int32, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Fixed32)
+ writeSFixed32(stream, value)
+
proc readSFixed32*(stream: ProtobufStream): int32 =
result = cast[int32](readFixed32(stream))
@@ -295,6 +345,10 @@ proc writeFloat*(stream: ProtobufStream, value: float32) =
write(stream, output)
+proc writeFloat*(stream: ProtobufStream, value: float32, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.Fixed32)
+ writeFloat(stream, value)
+
proc readFloat*(stream: ProtobufStream): float32 =
var tmp: float32
if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp):
@@ -305,9 +359,16 @@ proc writeString*(stream: ProtobufStream, s: string) =
writeUInt64(stream, len(s).uint64)
write(stream, s)
+proc writeString*(stream: ProtobufStream, s: string, fieldNumber: int) =
+ writeTag(stream, fieldNumber, WireType.LengthDelimited)
+ writeString(stream, s)
+
proc writeBytes*(stream: ProtobufStream, s: bytes) =
writeString(stream, string(s))
+proc writeBytes*(stream: ProtobufStream, s: bytes, fieldNumber: int) =
+ writeString(stream, string(s), fieldNumber)
+
proc safeReadStr*(stream: Stream, size: int): string =
result = newString(size)
if readData(stream, addr(result[0]), size) != size:
@@ -400,6 +461,12 @@ proc sizeOfSInt64*(value: int64): uint64 =
proc sizeOfEnum*[T](value: T): uint64 =
result = sizeOfUInt32(value.uint32)
+proc sizeOfLengthDelimited*(size: uint64): uint64 =
+ result = size + sizeOfVarint(size)
+
+proc sizeOfTag*(fieldNumber: int, wiretype: WireType): uint64 =
+ result = sizeOfUInt32(uint32(makeTag(fieldNumber, wiretype)))
+
proc skipField*(stream: ProtobufStream, wiretype: WireType) =
case wiretype
of WireType.Varint:
@@ -421,3 +488,30 @@ proc expectWireType*(actual: WireType, expected: varargs[WireType]) =
let message = "Got wiretype " & $actual & " but expected: " &
join(expected, ", ")
raise newException(UnexpectedWireTypeError, message)
+
+macro writeMessage*(stream: ProtobufStream, message: typed, fieldNumber: int): typed =
+ ## Write a message to a stream with tag and length.
+ let t = getTypeInst(message)
+ result = newStmtList(
+ newCall(
+ ident("writeTag"),
+ stream,
+ fieldNumber,
+ newDotExpr(ident("WireType"), ident("LengthDelimited"))
+ ),
+ newCall(
+ ident("writeVarint"),
+ stream,
+ newCall(ident("sizeOf" & $t), message)
+ ),
+ newCall(
+ ident("write" & $t),
+ stream,
+ message
+ )
+ )
+
+proc excl*(s: var IntSet, values: openArray[int]) =
+ ## Exclude multiple values from an IntSet.
+ for value in values:
+ excl(s, value)