diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2018-04-03 20:22:08 +0300 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2018-04-03 20:22:08 +0300 |
| commit | 54280463ac487e331daa43453058d02599f85102 (patch) | |
| tree | 8790efccb67cefcb08359131745cbf68552f0d84 | |
| parent | e2cfc6717ae8acd9c47a132fa2cbc8446b5920c5 (diff) | |
| download | nimpb-54280463ac487e331daa43453058d02599f85102.tar.gz nimpb-54280463ac487e331daa43453058d02599f85102.zip | |
Generate less and more cleaner code
| -rw-r--r-- | generator/protoc_gen_nim.nim | 161 | ||||
| -rw-r--r-- | src/protobuf/protobuf.nim | 94 |
2 files changed, 179 insertions, 76 deletions
diff --git a/generator/protoc_gen_nim.nim b/generator/protoc_gen_nim.nim index 1304ed9..3680de7 100644 --- a/generator/protoc_gen_nim.nim +++ b/generator/protoc_gen_nim.nim @@ -510,6 +510,9 @@ iterator genProcs(e: Enum): string = yield &"proc write{e.names}*(stream: ProtobufStream, value: {e.names}) =" yield indent(&"writeUInt32(stream, uint32(value))", 4) yield "" + yield &"proc write{e.names}*(stream: ProtobufStream, value: {e.names}, fieldNumber: int) =" + yield indent(&"writeUInt32(stream, uint32(value), fieldNumber)", 4) + yield "" yield &"proc sizeOf{e.names}*(value: {e.names}): uint64 =" yield indent(&"sizeOfUInt32(uint32(value))", 4) @@ -531,23 +534,33 @@ iterator oneofSiblings(field: Field): Field = iterator genClearFieldProc(msg: Message, field: Field): string = yield &"proc clear{field.name}*(message: {msg.names}) =" yield indent(&"message.{field.accessor} = {defaultValue(field)}", 4) - yield indent(&"excl(message.hasField, {field.number})", 4) + var numbers: seq[int] = @[field.number] for sibling in oneofSiblings(field): - yield indent(&"excl(message.hasField, {sibling.number})", 4) + add(numbers, sibling.number) + yield indent(&"excl(message.hasField, [{join(numbers, \", \")}])", 4) yield "" iterator genHasFieldProc(msg: Message, field: Field): string = # TODO: if map/seq, check also if there are values! yield &"proc has{field.name}*(message: {msg.names}): bool =" - yield indent(&"result = contains(message.hasField, {field.number})", 4) + var check = indent(&"result = contains(message.hasField, {field.number})", 4) + if isRepeated(field) or isMapEntry(field): + check = &"{check} or (len(message.{field.accessor}) > 0)" + yield check + # elif isMapEntry(field): + # base = &"{base} or (len(message.{}" + # yield indent(&"result = contains(message.hasField, {field.number})", 4) yield "" iterator genSetFieldProc(msg: Message, field: Field): string = yield &"proc set{field.name}*(message: {msg.names}, value: {field.fullType}) =" yield indent(&"message.{field.accessor} = value", 4) yield indent(&"incl(message.hasField, {field.number})", 4) + var numbers: seq[int] = @[] for sibling in oneofSiblings(field): - yield indent(&"excl(message.hasField, {sibling.number})", 4) + add(numbers, sibling.number) + if len(numbers) > 0: + yield indent(&"excl(message.hasField, [{join(numbers, \", \")}])", 4) yield "" iterator genAddToFieldProc(msg: Message, field: Field): string = @@ -572,13 +585,12 @@ iterator genWriteMapKVProc(msg: Message): string = yield &"proc write{msg.names}KV(stream: ProtobufStream, key: {key.fullType}, value: {value.fullType}) =" - yield indent(&"writeTag(stream, {key.number}, {wiretypeStr(key)})", 4) - yield indent(&"write{key.typeName}(stream, key)", 4) + yield indent(&"write{key.typeName}(stream, key, {key.number})", 4) - yield indent(&"writeTag(stream, {value.number}, {wiretypeStr(value)})", 4) if isMessage(value): - yield indent(&"writeVarint(stream, sizeOf{value.typeName}(value))", 4) - yield indent(&"write{value.typeName}(stream, value)", 4) + yield indent(&"writeMessage(stream, value, {value.number})", 4) + else: + yield indent(&"write{value.typeName}(stream, value, {value.number})", 4) yield "" @@ -600,16 +612,16 @@ iterator genWriteMessageProc(msg: Message): string = yield indent(&"{writer}(stream, value)", 12) else: yield indent(&"for value in message.{field.name}:", 4) - yield indent(&"writeTag(stream, {field.number}, {wiretypeStr(field)})", 8) if isMessage(field): - yield indent(&"writeVarint(stream, sizeOf{field.typeName}(value))", 8) - yield indent(&"{writer}(stream, value)", 8) + yield indent(&"writeMessage(stream, value, {field.number})", 8) + else: + yield indent(&"{writer}(stream, value, {field.number})", 8) else: yield indent(&"if has{field.name}(message):", 4) - yield indent(&"writeTag(stream, {field.number}, {wiretypeStr(field)})", 8) if isMessage(field): - yield indent(&"writeVarint(stream, sizeOf{field.typeName}(message.{field.accessor}))", 8) - yield indent(&"{writer}(stream, message.{field.accessor})", 8) + yield indent(&"writeMessage(stream, message.{field.accessor}, {field.number})", 8) + else: + yield indent(&"{writer}(stream, message.{field.accessor}, {field.number})", 8) if len(msg.fields) == 0: yield indent("discard", 4) @@ -725,13 +737,18 @@ iterator genSizeOfMapKVProc(message: Message): string = value = mapValueField(message) yield &"proc sizeOf{message.names}KV(key: {key.fullType}, value: {value.fullType}): uint64 =" + + # Key (cannot be message or other complex field) + yield indent(&"result = result + sizeOfTag({key.number}, {key.wiretypeStr})", 4) yield indent(&"result = result + sizeOf{key.typeName}(key)", 4) - yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({key.number}, {key.wiretypeStr})))", 4) - yield indent(&"let valueSize = sizeOf{value.typeName}(value)", 4) - yield indent(&"result = result + valueSize", 4) - yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({value.number}, {value.wiretypeStr})))", 4) + + # Value + yield indent(&"result = result + sizeOfTag({value.number}, {value.wiretypeStr})", 4) if isMessage(value): - yield indent(&"result = result + sizeOfUInt64(valueSize)", 4) + yield indent(&"result = result + sizeOfLengthDelimited(sizeOf{value.typeName}(value))", 4) + else: + yield indent(&"result = result + sizeOf{value.typeName}(value)", 4) + yield "" iterator genSizeOfMessageProc(msg: Message): string = @@ -742,36 +759,27 @@ iterator genSizeOfMessageProc(msg: Message): string = yield indent(&"var sizeOfKV = 0'u64", 8) yield indent(&"for key, value in message.{field.name}:", 8) yield indent(&"sizeOfKV = sizeOfKV + sizeOf{field.typeName}KV(key, value)", 12) - yield indent(&"let sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)})))", 8) - yield indent("result = result + sizeOfKV + sizeOfTag + sizeOfUInt64(sizeOfKV)", 8) + yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8) + yield indent(&"result = result + sizeOfLengthDelimited(sizeOfKV)", 8) elif isRepeated(field): if isNumeric(field): - yield indent(&""" -if has{field.name}(message): - let - sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, WireType.LengthDelimited))) - sizeOfData = packedFieldSize(message.{field.name}, {field.fieldTypeStr}) - sizeOfSize = sizeOfUInt64(sizeOfData) - result = sizeOfTag + sizeOfData + sizeOfSize""", 4) + yield indent(&"if has{field.name}(message):", 4) + yield indent(&"result = result + sizeOfTag({field.number}, WireType.LengthDelimited)", 8) + yield indent(&"result = result + sizeOfLengthDelimited(packedFieldSize(message.{field.name}, {field.fieldTypeStr}))", 8) else: - yield indent(&""" -for value in message.{field.name}: - let - sizeOfValue = sizeOf{field.typeName}(value) - sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)}))) - result = result + sizeOfValue + sizeOfTag -""", 4) + yield indent(&"for value in message.{field.name}:", 4) + yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8) if isMessage(field): - yield indent("result = result + sizeOfUInt64(sizeOfValue)", 8) + yield indent(&"result = result + sizeOfLengthDelimited(sizeOf{field.typeName}(value))", 8) + else: + yield indent(&"result = result + sizeOf{field.typeName}(value)", 8) else: - yield indent(&""" -if has{field.name}(message): - let - sizeOfField = sizeOf{field.typeName}(message.{field.accessor}) - sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)}))) - result = result + sizeOfField + sizeOfTag""", 4) + yield indent(&"if has{field.name}(message):", 4) + yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8) if isMessage(field): - yield indent("result = result + sizeOfUInt64(sizeOfField)", 8) + yield indent(&"result = result + sizeOfLengthDelimited(sizeOf{field.typeName}(message.{field.accessor}))", 8) + else: + yield indent(&"result = result + sizeOf{field.typeName}(message.{field.accessor})", 8) if len(msg.fields) == 0: yield indent("result = 0", 4) @@ -779,10 +787,11 @@ if has{field.name}(message): yield "" iterator genMessageProcForwards(msg: Message): string = - yield &"proc new{msg.names}*(): {msg.names}" - yield &"proc write{msg.names}*(stream: ProtobufStream, message: {msg.names})" - yield &"proc read{msg.names}*(stream: ProtobufStream): {msg.names}" - yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64" + if not isMapEntry(msg): + yield &"proc new{msg.names}*(): {msg.names}" + yield &"proc write{msg.names}*(stream: ProtobufStream, message: {msg.names})" + yield &"proc read{msg.names}*(stream: ProtobufStream): {msg.names}" + yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64" if isMapEntry(msg): let @@ -794,41 +803,41 @@ iterator genMessageProcForwards(msg: Message): string = yield &"proc sizeOf{msg.names}KV(key: {key.fullType}, value: {value.fullType}): uint64" iterator genProcs(msg: Message): string = - for line in genNewMessageProc(msg): yield line - - for field in msg.fields: - for line in genClearFieldProc(msg, field): yield line - for line in genHasFieldProc(msg, field): yield line - for line in genSetFieldProc(msg, field): yield line - - if isRepeated(field) and not isMapEntry(field): - for line in genAddToFieldProc(msg, field): yield line - - for line in genFieldAccessorProcs(msg, field): yield line - if isMapEntry(msg): for line in genSizeOfMapKVProc(msg): yield line for line in genWriteMapKVProc(msg): yield line for line in genReadMapKVProc(msg): yield line + else: + for line in genNewMessageProc(msg): yield line - for line in genSizeOfMessageProc(msg): yield line - for line in genWriteMessageProc(msg): yield line - for line in genReadMessageProc(msg): yield line + for field in msg.fields: + for line in genClearFieldProc(msg, field): yield line + for line in genHasFieldProc(msg, field): yield line + for line in genSetFieldProc(msg, field): yield line - yield &"proc serialize*(message: {msg.names}): string =" - yield indent("let", 4) - yield indent("ss = newStringStream()", 8) - yield indent("pbs = newProtobufStream(ss)", 8) - yield indent(&"write{msg.names}(pbs, message)", 4) - yield indent("result = ss.data", 4) - yield "" + if isRepeated(field) and not isMapEntry(field): + for line in genAddToFieldProc(msg, field): yield line - yield &"proc new{msg.names}*(data: string): {msg.names} =" - yield indent("let", 4) - yield indent("ss = newStringStream(data)", 8) - yield indent("pbs = newProtobufStream(ss)", 8) - yield indent(&"result = read{msg.names}(pbs)", 4) - yield "" + for line in genFieldAccessorProcs(msg, field): yield line + + for line in genSizeOfMessageProc(msg): yield line + for line in genWriteMessageProc(msg): yield line + for line in genReadMessageProc(msg): yield line + + yield &"proc serialize*(message: {msg.names}): string =" + yield indent("let", 4) + yield indent("ss = newStringStream()", 8) + yield indent("pbs = newProtobufStream(ss)", 8) + yield indent(&"write{msg.names}(pbs, message)", 4) + yield indent("result = ss.data", 4) + yield "" + + yield &"proc new{msg.names}*(data: string): {msg.names} =" + yield indent("let", 4) + yield indent("ss = newStringStream(data)", 8) + yield indent("pbs = newProtobufStream(ss)", 8) + yield indent(&"result = read{msg.names}(pbs)", 4) + yield "" proc processFile(filename: string, fdesc: FileDescriptorProto, otherFiles: TableRef[string, ProtoFile]): ProcessedFile = diff --git a/src/protobuf/protobuf.nim b/src/protobuf/protobuf.nim index 98bb7a4..d15bfc7 100644 --- a/src/protobuf/protobuf.nim +++ b/src/protobuf/protobuf.nim @@ -1,4 +1,6 @@ import endians +import intsets +import macros import streams import strutils @@ -190,42 +192,70 @@ proc readTag*(stream: ProtobufStream): Tag = proc writeInt32*(stream: ProtobufStream, n: int32) = writeVarint(stream, n.uint64) +proc writeInt32*(stream: ProtobufStream, n: int32, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Varint) + writeInt32(stream, n) + proc readInt32*(stream: ProtobufStream): int32 = result = readVarint(stream).int32 proc writeSInt32*(stream: ProtobufStream, n: int32) = writeVarint(stream, zigzagEncode(n)) +proc writeSInt32*(stream: ProtobufStream, n: int32, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Varint) + writeVarint(stream, zigzagEncode(n)) + proc readSInt32*(stream: ProtobufStream): int32 = result = zigzagDecode(readVarint(stream).uint32) proc writeUInt32*(stream: ProtobufStream, n: uint32) = writeVarint(stream, n) +proc writeUInt32*(stream: ProtobufStream, n: uint32, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Varint) + writeVarint(stream, n) + proc readUInt32*(stream: ProtobufStream): uint32 = result = readVarint(stream).uint32 proc writeInt64*(stream: ProtobufStream, n: int64) = writeVarint(stream, n.uint64) +proc writeInt64*(stream: ProtobufStream, n: int64, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Varint) + writeVarint(stream, n.uint64) + proc readInt64*(stream: ProtobufStream): int64 = result = readVarint(stream).int64 proc writeSInt64*(stream: ProtobufStream, n: int64) = writeVarint(stream, zigzagEncode(n)) +proc writeSInt64*(stream: ProtobufStream, n: int64, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Varint) + writeVarint(stream, zigzagEncode(n)) + proc readSInt64*(stream: ProtobufStream): int64 = result = zigzagDecode(readVarint(stream)) proc writeUInt64*(stream: ProtobufStream, n: uint64) = writeVarint(stream, n) +proc writeUInt64*(stream: ProtobufStream, n: uint64, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Varint) + writeVarint(stream, n) + proc readUInt64*(stream: ProtobufStream): uint64 = result = readVarint(stream) proc writeBool*(stream: ProtobufStream, value: bool) = writeVarint(stream, value.uint32) +proc writeBool*(stream: ProtobufStream, n: bool, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Varint) + writeVarint(stream, n.uint32) + proc readBool*(stream: ProtobufStream): bool = result = readVarint(stream).bool @@ -238,6 +268,10 @@ proc writeFixed64*(stream: ProtobufStream, value: uint64) = write(stream, output) +proc writeFixed64*(stream: ProtobufStream, n: uint64, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Fixed64) + writeFixed64(stream, n) + proc readFixed64*(stream: ProtobufStream): uint64 = var tmp: uint64 if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp): @@ -247,6 +281,10 @@ proc readFixed64*(stream: ProtobufStream): uint64 = proc writeSFixed64*(stream: ProtobufStream, value: int64) = writeFixed64(stream, cast[uint64](value)) +proc writeSFixed64*(stream: ProtobufStream, value: int64, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Fixed64) + writeSFixed64(stream, value) + proc readSFixed64*(stream: ProtobufStream): int64 = result = cast[int64](readFixed64(stream)) @@ -259,6 +297,10 @@ proc writeDouble*(stream: ProtobufStream, value: float64) = write(stream, output) +proc writeDouble*(stream: ProtobufStream, value: float64, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Fixed64) + writeDouble(stream, value) + proc readDouble*(stream: ProtobufStream): float64 = var tmp: uint64 if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp): @@ -274,6 +316,10 @@ proc writeFixed32*(stream: ProtobufStream, value: uint32) = write(stream, output) +proc writeFixed32*(stream: ProtobufStream, value: uint32, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Fixed32) + writeFixed32(stream, value) + proc readFixed32*(stream: ProtobufStream): uint32 = var tmp: uint32 if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp): @@ -283,6 +329,10 @@ proc readFixed32*(stream: ProtobufStream): uint32 = proc writeSFixed32*(stream: ProtobufStream, value: int32) = writeFixed32(stream, cast[uint32](value)) +proc writeSFixed32*(stream: ProtobufStream, value: int32, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Fixed32) + writeSFixed32(stream, value) + proc readSFixed32*(stream: ProtobufStream): int32 = result = cast[int32](readFixed32(stream)) @@ -295,6 +345,10 @@ proc writeFloat*(stream: ProtobufStream, value: float32) = write(stream, output) +proc writeFloat*(stream: ProtobufStream, value: float32, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.Fixed32) + writeFloat(stream, value) + proc readFloat*(stream: ProtobufStream): float32 = var tmp: float32 if readData(stream, addr(tmp), sizeof(tmp)) != sizeof(tmp): @@ -305,9 +359,16 @@ proc writeString*(stream: ProtobufStream, s: string) = writeUInt64(stream, len(s).uint64) write(stream, s) +proc writeString*(stream: ProtobufStream, s: string, fieldNumber: int) = + writeTag(stream, fieldNumber, WireType.LengthDelimited) + writeString(stream, s) + proc writeBytes*(stream: ProtobufStream, s: bytes) = writeString(stream, string(s)) +proc writeBytes*(stream: ProtobufStream, s: bytes, fieldNumber: int) = + writeString(stream, string(s), fieldNumber) + proc safeReadStr*(stream: Stream, size: int): string = result = newString(size) if readData(stream, addr(result[0]), size) != size: @@ -400,6 +461,12 @@ proc sizeOfSInt64*(value: int64): uint64 = proc sizeOfEnum*[T](value: T): uint64 = result = sizeOfUInt32(value.uint32) +proc sizeOfLengthDelimited*(size: uint64): uint64 = + result = size + sizeOfVarint(size) + +proc sizeOfTag*(fieldNumber: int, wiretype: WireType): uint64 = + result = sizeOfUInt32(uint32(makeTag(fieldNumber, wiretype))) + proc skipField*(stream: ProtobufStream, wiretype: WireType) = case wiretype of WireType.Varint: @@ -421,3 +488,30 @@ proc expectWireType*(actual: WireType, expected: varargs[WireType]) = let message = "Got wiretype " & $actual & " but expected: " & join(expected, ", ") raise newException(UnexpectedWireTypeError, message) + +macro writeMessage*(stream: ProtobufStream, message: typed, fieldNumber: int): typed = + ## Write a message to a stream with tag and length. + let t = getTypeInst(message) + result = newStmtList( + newCall( + ident("writeTag"), + stream, + fieldNumber, + newDotExpr(ident("WireType"), ident("LengthDelimited")) + ), + newCall( + ident("writeVarint"), + stream, + newCall(ident("sizeOf" & $t), message) + ), + newCall( + ident("write" & $t), + stream, + message + ) + ) + +proc excl*(s: var IntSet, values: openArray[int]) = + ## Exclude multiple values from an IntSet. + for value in values: + excl(s, value) |
