diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2018-04-01 16:01:55 +0300 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2018-04-01 16:01:55 +0300 |
| commit | 75ec5d643ce7a10b05bb18de3fc7a3e678183103 (patch) | |
| tree | 7937e8d7ce67ea9d7df0d6d10cb70c9a13ca9e88 | |
| parent | 35864af115c17b721bfdc926e0e49ca26a3c431b (diff) | |
| download | nimpb-75ec5d643ce7a10b05bb18de3fc7a3e678183103.tar.gz nimpb-75ec5d643ce7a10b05bb18de3fc7a3e678183103.zip | |
Add initial native support for maps
| -rw-r--r-- | generator/descriptor_pb.nim | 27 | ||||
| -rw-r--r-- | generator/protoc_gen_nim.nim | 176 |
2 files changed, 191 insertions, 12 deletions
diff --git a/generator/descriptor_pb.nim b/generator/descriptor_pb.nim index 05c21b9..890a61d 100644 --- a/generator/descriptor_pb.nim +++ b/generator/descriptor_pb.nim @@ -128,6 +128,15 @@ const packed: false, oneofIdx: -1, ), + FieldDesc( + name: "options", + number: 7, + ftype: FieldType.Message, + label: FieldLabel.Optional, + typeName: "MessageOptions", + packed: false, + oneofIdx: -1, + ), ] ) @@ -281,6 +290,21 @@ const ] ) + MessageOptionsDesc = MessageDesc( + name: "MessageOptions", + fields: @[ + FieldDesc( + name: "map_entry", + number: 7, + ftype: FieldType.Bool, + label: FieldLabel.Optional, + typeName: "", + packed: false, + oneofIdx: -1, + ), + ] + ) + FieldOptionsDesc = MessageDesc( name: "FieldOptions", fields: @[ @@ -332,6 +356,9 @@ generateMessageProcs(FieldDescriptorProtoDesc) generateMessageType(OneofDescriptorProtoDesc) generateMessageProcs(OneofDescriptorProtoDesc) +generateMessageType(MessageOptionsDesc) +generateMessageProcs(MessageOptionsDesc) + generateMessageType(DescriptorProtoDesc) generateMessageProcs(DescriptorProtoDesc) diff --git a/generator/protoc_gen_nim.nim b/generator/protoc_gen_nim.nim index ed7fac3..4a1e3df 100644 --- a/generator/protoc_gen_nim.nim +++ b/generator/protoc_gen_nim.nim @@ -28,11 +28,13 @@ type typeName: string packed: bool oneof: Oneof + mapEntry: Message Message = ref object names: Names fields: seq[Field] oneofs: seq[Oneof] + mapEntry: bool Oneof = ref object name: string @@ -97,6 +99,12 @@ proc isNumeric(field: Field): bool = result = true else: discard +proc isMapEntry(message: Message): bool = + result = message.mapEntry + +proc isMapEntry(field: Field): bool = + result = field.mapEntry != nil + proc nimTypeName(field: Field): string = case field.ftype of FieldDescriptorProtoType.TypeDouble: result = "float64" @@ -118,6 +126,16 @@ proc nimTypeName(field: Field): string = of FieldDescriptorProtoType.TypeSInt32: result = "int32" of FieldDescriptorProtoType.TypeSInt64: result = "int64" +proc mapKeyType(field: Field): string = + for f in field.mapEntry.fields: + if f.name == "key": + return f.nimTypeName + +proc mapValueType(field: Field): string = + for f in field.mapEntry.fields: + if f.name == "value": + return f.nimTypeName + proc `$`(ft: FieldDescriptorProtoType): string = case ft of FieldDescriptorProtoType.TypeDouble: result = "Double" @@ -140,7 +158,9 @@ proc `$`(ft: FieldDescriptorProtoType): string = of FieldDescriptorProtoType.TypeSInt64: result = "SInt64" proc defaultValue(field: Field): string = - if isRepeated(field): + if isMapEntry(field): + return &"newTable[{field.mapKeyType}, {field.mapValueType}]()" + elif isRepeated(field): return "@[]" case field.ftype @@ -209,11 +229,12 @@ proc newField(file: ProtoFile, message: Message, desc: FieldDescriptorProto): Fi result.ftype = desc.type result.typeName = "" result.packed = false + result.mapEntry = nil if isKeyword(result.name): result.name = "f" & result.name - if isNumeric(result): + if isRepeated(result) and isNumeric(result): if hasOptions(desc): if hasPacked(desc.options): result.packed = desc.options.packed @@ -252,6 +273,10 @@ proc newMessage(file: ProtoFile, names: Names, desc: DescriptorProto): Message = result.names = names result.fields = @[] result.oneofs = @[] + result.mapEntry = false + + if hasMapEntry(desc.options): + result.mapEntry = desc.options.mapEntry log(&"newMessage {$result.names}") @@ -261,6 +286,15 @@ proc newMessage(file: ProtoFile, names: Names, desc: DescriptorProto): Message = for field in desc.field: add(result.fields, newField(file, result, field)) +proc fixMapEntry(file: ProtoFile, message: Message): bool = + for field in message.fields: + for msg in file.messages: + if $msg.names == field.typeName: + if msg.mapEntry: + log(&"fixing map {field.name} {msg.names}") + field.mapEntry = msg + result = true + proc newEnum(names: Names, desc: EnumDescriptorProto): Enum = new(result) @@ -416,9 +450,22 @@ iterator genType(e: Enum): string = yield indent(&"{name} = {number}", 4) proc fullType(field: Field): string = - result = field.nimTypeName - if isRepeated(field): - result = &"seq[{result}]" + if isMapEntry(field): + result = &"TableRef[{field.mapKeyType}, {field.mapValueType}]" + else: + result = field.nimTypeName + if isRepeated(field): + result = &"seq[{result}]" + +proc mapKeyField(message: Message): Field = + for field in message.fields: + if field.name == "key": + return field + +proc mapValueField(message: Message): Field = + for field in message.fields: + if field.name == "value": + return field iterator genType(message: Message): string = yield &"{message.names}* = ref {message.names}Obj" @@ -426,7 +473,9 @@ iterator genType(message: Message): string = yield indent(&"hasField: IntSet", 4) for field in message.fields: - if field.oneof == nil: + if isMapEntry(field): + yield indent(&"{field.name}: TableRef[{mapKeyType(field)}, {mapValueType(field)}]", 4) + elif field.oneof == nil: yield indent(&"{quoteReserved(field.name)}: {field.fullType}", 4) for oneof in message.oneofs: @@ -472,6 +521,7 @@ iterator genClearFieldProc(msg: Message, field: Field): string = yield "" iterator genHasFieldProc(msg: Message, field: Field): string = + # TODO: if map/seq, check also if there are values! yield &"proc has{field.name}*(message: {msg.names}): bool =" yield indent(&"result = contains(message.hasField, {field.number})", 4) yield "" @@ -499,11 +549,33 @@ iterator genFieldAccessorProcs(msg: Message, field: Field): string = yield indent(&"set{field.name}(message, value)", 4) yield "" +iterator genWriteMapKVProc(msg: Message): string = + let + key = mapKeyField(msg) + value = mapValueField(msg) + + yield &"proc write{msg.names}KV(stream: ProtobufStream, key: {key.fullType}, value: {value.fullType}) =" + + yield indent(&"writeTag(stream, {key.number}, {wiretypeStr(key)})", 4) + yield indent(&"write{key.typeName}(stream, key)", 4) + + yield indent(&"writeTag(stream, {value.number}, {wiretypeStr(value)})", 4) + if isMessage(value): + yield indent(&"writeVarint(stream, sizeOf{value.typeName}(value))", 4) + yield indent(&"write{value.typeName}(stream, value)", 4) + + yield "" + iterator genWriteMessageProc(msg: Message): string = yield &"proc write{msg.names}*(stream: ProtobufStream, message: {msg.names}) =" for field in msg.fields: let writer = "write" & field.typeName - if isRepeated(field): + if isMapEntry(field): + yield indent(&"for key, value in message.{field.name}:", 4) + yield indent(&"writeTag(stream, {field.number}, {wiretypeStr(field)})", 8) + yield indent(&"writeVarint(stream, sizeOf{field.typeName}KV(key, value))", 8) + yield indent(&"write{field.typeName}KV(stream, key, value)", 8) + elif isRepeated(field): if field.packed: yield indent(&"if has{field.name}(message):", 4) yield indent(&"writeTag(stream, {field.number}, WireType.LengthDelimited)", 8) @@ -524,6 +596,44 @@ iterator genWriteMessageProc(msg: Message): string = yield indent(&"{writer}(stream, message.{field.accessor})", 8) yield "" +iterator genReadMapKVProc(msg: Message): string = + let + key = mapKeyField(msg) + value = mapValueField(msg) + + yield &"proc read{msg.names}KV(stream: ProtobufStream, tbl: TableRef[{key.fullType}, {value.fullType}]) =" + + yield indent(&"var", 4) + yield indent(&"key: {key.fullType}", 8) + yield indent("gotKey = false", 8) + yield indent(&"value: {value.fullType}", 8) + yield indent("gotValue = false", 8) + yield indent("while not atEnd(stream):", 4) + yield indent("let", 8) + yield indent("tag = readTag(stream)", 12) + yield indent("wireType = getTagWireType(tag)", 12) + yield indent("case getTagFieldNumber(tag)", 8) + yield indent(&"of {key.number}:", 8) + yield indent(&"key = read{key.typeName}(stream)", 12) + yield indent("gotKey = true", 12) + yield indent(&"of {value.number}:", 8) + if isMessage(value): + yield indent("let", 12) + yield indent("size = readVarint(stream)", 16) + yield indent("data = readStr(stream, int(size))", 16) + yield indent("pbs = newProtobufStream(newStringStream(data))", 16) + yield indent(&"value = read{value.typeName}(pbs)", 12) + else: + yield indent(&"value = read{value.typeName}(stream)", 12) + yield indent("gotValue = true", 12) + yield indent("else: skipField(stream, wireType)", 8) + yield indent("if not gotKey:", 4) + yield indent(&"raise newException(Exception, \"missing key ({msg.names})\")", 8) + yield indent("if not gotValue:", 4) + yield indent(&"raise newException(Exception, \"missing value ({msg.names})\")", 8) + yield indent("tbl[key] = value", 4) + yield "" + iterator genReadMessageProc(msg: Message): string = yield &"proc read{msg.names}*(stream: ProtobufStream): {msg.names} =" yield indent(&"result = new{msg.names}()", 4) @@ -542,7 +652,13 @@ iterator genReadMessageProc(msg: Message): string = &"set{field.name}" yield indent(&"of {field.number}:", 8) if isRepeated(field): - if isNumeric(field): + if isMapEntry(field): + yield indent("let", 12) + yield indent("size = readVarint(stream)", 16) + yield indent("data = readStr(stream, int(size))", 16) + yield indent("pbs = newProtobufStream(newStringStream(data))", 16) + yield indent(&"read{field.typeName}KV(pbs, result.{field.name})", 12) + elif isNumeric(field): yield indent("if wireType == WireType.LengthDelimited:", 12) yield indent("let", 16) yield indent("size = readVarint(stream)", 20) @@ -575,10 +691,32 @@ iterator genReadMessageProc(msg: Message): string = yield indent("else: skipField(stream, wireType)", 8) yield "" +iterator genSizeOfMapKVProc(message: Message): string = + let + key = mapKeyField(message) + value = mapValueField(message) + + yield &"proc sizeOf{message.names}KV(key: {key.fullType}, value: {value.fullType}): uint64 =" + yield indent(&"result = result + sizeOf{key.typeName}(key)", 4) + yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({key.number}, {key.wiretypeStr})))", 4) + yield indent(&"let valueSize = sizeOf{value.typeName}(value)", 4) + yield indent(&"result = result + valueSize", 4) + yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({value.number}, {value.wiretypeStr})))", 4) + if isMessage(value): + yield indent(&"result = result + sizeOfUInt64(valueSize)", 4) + yield "" + iterator genSizeOfMessageProc(msg: Message): string = yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64 =" for field in msg.fields: - if isRepeated(field): + if isMapEntry(field): + yield indent(&"if has{field.name}(message):", 4) + yield indent(&"var sizeOfKV = 0'u64", 8) + yield indent(&"for key, value in message.{field.name}:", 8) + yield indent(&"sizeOfKV = sizeOfKV + sizeOf{field.typeName}KV(key, value)", 12) + yield indent(&"let sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)})))", 8) + yield indent("result = result + sizeOfKV + sizeOfTag + sizeOfUInt64(sizeOfKV)", 8) + elif isRepeated(field): if isNumeric(field): yield indent(&""" if has{field.name}(message): @@ -617,11 +755,16 @@ iterator genProcs(msg: Message): string = for line in genHasFieldProc(msg, field): yield line for line in genSetFieldProc(msg, field): yield line - if isRepeated(field): + if isRepeated(field) and not isMapEntry(field): for line in genAddToFieldProc(msg, field): yield line for line in genFieldAccessorProcs(msg, field): yield line + if isMapEntry(msg): + for line in genSizeOfMapKVProc(msg): yield line + for line in genWriteMapKVProc(msg): yield line + for line in genReadMapKVProc(msg): yield line + for line in genSizeOfMessageProc(msg): yield line for line in genWriteMessageProc(msg): yield line for line in genReadMessageProc(msg): yield line @@ -652,9 +795,20 @@ proc processFile(filename: string, fdesc: FileDescriptorProto, result.name = pbfilename result.data = "" + let parsed = parseFile(filename, fdesc) + + var hasMaps = false + for message in parsed.messages: + let tmp = fixMapEntry(parsed, message) + if tmp: + hasMaps = true + addLine(result.data, "# Generated by protoc_gen_nim. Do not edit!") addLine(result.data, "") addLine(result.data, "import intsets") + if hasMaps: + addLine(result.data, "import tables") + addLine(result.data, "export tables") addLine(result.data, "") addLine(result.data, "import protobuf/stream") addLine(result.data, "import protobuf/types") @@ -668,8 +822,6 @@ proc processFile(filename: string, fdesc: FileDescriptorProto, if hasDependency(fdesc): addLine(result.data, "") - let parsed = parseFile(filename, fdesc) - addLine(result.data, "type") for e in parsed.enums: |
