aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOskari Timperi <oskari.timperi@iki.fi>2018-04-01 16:01:55 +0300
committerOskari Timperi <oskari.timperi@iki.fi>2018-04-01 16:01:55 +0300
commit75ec5d643ce7a10b05bb18de3fc7a3e678183103 (patch)
tree7937e8d7ce67ea9d7df0d6d10cb70c9a13ca9e88
parent35864af115c17b721bfdc926e0e49ca26a3c431b (diff)
downloadnimpb-75ec5d643ce7a10b05bb18de3fc7a3e678183103.tar.gz
nimpb-75ec5d643ce7a10b05bb18de3fc7a3e678183103.zip
Add initial native support for maps
-rw-r--r--generator/descriptor_pb.nim27
-rw-r--r--generator/protoc_gen_nim.nim176
2 files changed, 191 insertions, 12 deletions
diff --git a/generator/descriptor_pb.nim b/generator/descriptor_pb.nim
index 05c21b9..890a61d 100644
--- a/generator/descriptor_pb.nim
+++ b/generator/descriptor_pb.nim
@@ -128,6 +128,15 @@ const
packed: false,
oneofIdx: -1,
),
+ FieldDesc(
+ name: "options",
+ number: 7,
+ ftype: FieldType.Message,
+ label: FieldLabel.Optional,
+ typeName: "MessageOptions",
+ packed: false,
+ oneofIdx: -1,
+ ),
]
)
@@ -281,6 +290,21 @@ const
]
)
+ MessageOptionsDesc = MessageDesc(
+ name: "MessageOptions",
+ fields: @[
+ FieldDesc(
+ name: "map_entry",
+ number: 7,
+ ftype: FieldType.Bool,
+ label: FieldLabel.Optional,
+ typeName: "",
+ packed: false,
+ oneofIdx: -1,
+ ),
+ ]
+ )
+
FieldOptionsDesc = MessageDesc(
name: "FieldOptions",
fields: @[
@@ -332,6 +356,9 @@ generateMessageProcs(FieldDescriptorProtoDesc)
generateMessageType(OneofDescriptorProtoDesc)
generateMessageProcs(OneofDescriptorProtoDesc)
+generateMessageType(MessageOptionsDesc)
+generateMessageProcs(MessageOptionsDesc)
+
generateMessageType(DescriptorProtoDesc)
generateMessageProcs(DescriptorProtoDesc)
diff --git a/generator/protoc_gen_nim.nim b/generator/protoc_gen_nim.nim
index ed7fac3..4a1e3df 100644
--- a/generator/protoc_gen_nim.nim
+++ b/generator/protoc_gen_nim.nim
@@ -28,11 +28,13 @@ type
typeName: string
packed: bool
oneof: Oneof
+ mapEntry: Message
Message = ref object
names: Names
fields: seq[Field]
oneofs: seq[Oneof]
+ mapEntry: bool
Oneof = ref object
name: string
@@ -97,6 +99,12 @@ proc isNumeric(field: Field): bool =
result = true
else: discard
+proc isMapEntry(message: Message): bool =
+ result = message.mapEntry
+
+proc isMapEntry(field: Field): bool =
+ result = field.mapEntry != nil
+
proc nimTypeName(field: Field): string =
case field.ftype
of FieldDescriptorProtoType.TypeDouble: result = "float64"
@@ -118,6 +126,16 @@ proc nimTypeName(field: Field): string =
of FieldDescriptorProtoType.TypeSInt32: result = "int32"
of FieldDescriptorProtoType.TypeSInt64: result = "int64"
+proc mapKeyType(field: Field): string =
+ for f in field.mapEntry.fields:
+ if f.name == "key":
+ return f.nimTypeName
+
+proc mapValueType(field: Field): string =
+ for f in field.mapEntry.fields:
+ if f.name == "value":
+ return f.nimTypeName
+
proc `$`(ft: FieldDescriptorProtoType): string =
case ft
of FieldDescriptorProtoType.TypeDouble: result = "Double"
@@ -140,7 +158,9 @@ proc `$`(ft: FieldDescriptorProtoType): string =
of FieldDescriptorProtoType.TypeSInt64: result = "SInt64"
proc defaultValue(field: Field): string =
- if isRepeated(field):
+ if isMapEntry(field):
+ return &"newTable[{field.mapKeyType}, {field.mapValueType}]()"
+ elif isRepeated(field):
return "@[]"
case field.ftype
@@ -209,11 +229,12 @@ proc newField(file: ProtoFile, message: Message, desc: FieldDescriptorProto): Fi
result.ftype = desc.type
result.typeName = ""
result.packed = false
+ result.mapEntry = nil
if isKeyword(result.name):
result.name = "f" & result.name
- if isNumeric(result):
+ if isRepeated(result) and isNumeric(result):
if hasOptions(desc):
if hasPacked(desc.options):
result.packed = desc.options.packed
@@ -252,6 +273,10 @@ proc newMessage(file: ProtoFile, names: Names, desc: DescriptorProto): Message =
result.names = names
result.fields = @[]
result.oneofs = @[]
+ result.mapEntry = false
+
+ if hasMapEntry(desc.options):
+ result.mapEntry = desc.options.mapEntry
log(&"newMessage {$result.names}")
@@ -261,6 +286,15 @@ proc newMessage(file: ProtoFile, names: Names, desc: DescriptorProto): Message =
for field in desc.field:
add(result.fields, newField(file, result, field))
+proc fixMapEntry(file: ProtoFile, message: Message): bool =
+ for field in message.fields:
+ for msg in file.messages:
+ if $msg.names == field.typeName:
+ if msg.mapEntry:
+ log(&"fixing map {field.name} {msg.names}")
+ field.mapEntry = msg
+ result = true
+
proc newEnum(names: Names, desc: EnumDescriptorProto): Enum =
new(result)
@@ -416,9 +450,22 @@ iterator genType(e: Enum): string =
yield indent(&"{name} = {number}", 4)
proc fullType(field: Field): string =
- result = field.nimTypeName
- if isRepeated(field):
- result = &"seq[{result}]"
+ if isMapEntry(field):
+ result = &"TableRef[{field.mapKeyType}, {field.mapValueType}]"
+ else:
+ result = field.nimTypeName
+ if isRepeated(field):
+ result = &"seq[{result}]"
+
+proc mapKeyField(message: Message): Field =
+ for field in message.fields:
+ if field.name == "key":
+ return field
+
+proc mapValueField(message: Message): Field =
+ for field in message.fields:
+ if field.name == "value":
+ return field
iterator genType(message: Message): string =
yield &"{message.names}* = ref {message.names}Obj"
@@ -426,7 +473,9 @@ iterator genType(message: Message): string =
yield indent(&"hasField: IntSet", 4)
for field in message.fields:
- if field.oneof == nil:
+ if isMapEntry(field):
+ yield indent(&"{field.name}: TableRef[{mapKeyType(field)}, {mapValueType(field)}]", 4)
+ elif field.oneof == nil:
yield indent(&"{quoteReserved(field.name)}: {field.fullType}", 4)
for oneof in message.oneofs:
@@ -472,6 +521,7 @@ iterator genClearFieldProc(msg: Message, field: Field): string =
yield ""
iterator genHasFieldProc(msg: Message, field: Field): string =
+ # TODO: if map/seq, check also if there are values!
yield &"proc has{field.name}*(message: {msg.names}): bool ="
yield indent(&"result = contains(message.hasField, {field.number})", 4)
yield ""
@@ -499,11 +549,33 @@ iterator genFieldAccessorProcs(msg: Message, field: Field): string =
yield indent(&"set{field.name}(message, value)", 4)
yield ""
+iterator genWriteMapKVProc(msg: Message): string =
+ let
+ key = mapKeyField(msg)
+ value = mapValueField(msg)
+
+ yield &"proc write{msg.names}KV(stream: ProtobufStream, key: {key.fullType}, value: {value.fullType}) ="
+
+ yield indent(&"writeTag(stream, {key.number}, {wiretypeStr(key)})", 4)
+ yield indent(&"write{key.typeName}(stream, key)", 4)
+
+ yield indent(&"writeTag(stream, {value.number}, {wiretypeStr(value)})", 4)
+ if isMessage(value):
+ yield indent(&"writeVarint(stream, sizeOf{value.typeName}(value))", 4)
+ yield indent(&"write{value.typeName}(stream, value)", 4)
+
+ yield ""
+
iterator genWriteMessageProc(msg: Message): string =
yield &"proc write{msg.names}*(stream: ProtobufStream, message: {msg.names}) ="
for field in msg.fields:
let writer = "write" & field.typeName
- if isRepeated(field):
+ if isMapEntry(field):
+ yield indent(&"for key, value in message.{field.name}:", 4)
+ yield indent(&"writeTag(stream, {field.number}, {wiretypeStr(field)})", 8)
+ yield indent(&"writeVarint(stream, sizeOf{field.typeName}KV(key, value))", 8)
+ yield indent(&"write{field.typeName}KV(stream, key, value)", 8)
+ elif isRepeated(field):
if field.packed:
yield indent(&"if has{field.name}(message):", 4)
yield indent(&"writeTag(stream, {field.number}, WireType.LengthDelimited)", 8)
@@ -524,6 +596,44 @@ iterator genWriteMessageProc(msg: Message): string =
yield indent(&"{writer}(stream, message.{field.accessor})", 8)
yield ""
+iterator genReadMapKVProc(msg: Message): string =
+ let
+ key = mapKeyField(msg)
+ value = mapValueField(msg)
+
+ yield &"proc read{msg.names}KV(stream: ProtobufStream, tbl: TableRef[{key.fullType}, {value.fullType}]) ="
+
+ yield indent(&"var", 4)
+ yield indent(&"key: {key.fullType}", 8)
+ yield indent("gotKey = false", 8)
+ yield indent(&"value: {value.fullType}", 8)
+ yield indent("gotValue = false", 8)
+ yield indent("while not atEnd(stream):", 4)
+ yield indent("let", 8)
+ yield indent("tag = readTag(stream)", 12)
+ yield indent("wireType = getTagWireType(tag)", 12)
+ yield indent("case getTagFieldNumber(tag)", 8)
+ yield indent(&"of {key.number}:", 8)
+ yield indent(&"key = read{key.typeName}(stream)", 12)
+ yield indent("gotKey = true", 12)
+ yield indent(&"of {value.number}:", 8)
+ if isMessage(value):
+ yield indent("let", 12)
+ yield indent("size = readVarint(stream)", 16)
+ yield indent("data = readStr(stream, int(size))", 16)
+ yield indent("pbs = newProtobufStream(newStringStream(data))", 16)
+ yield indent(&"value = read{value.typeName}(pbs)", 12)
+ else:
+ yield indent(&"value = read{value.typeName}(stream)", 12)
+ yield indent("gotValue = true", 12)
+ yield indent("else: skipField(stream, wireType)", 8)
+ yield indent("if not gotKey:", 4)
+ yield indent(&"raise newException(Exception, \"missing key ({msg.names})\")", 8)
+ yield indent("if not gotValue:", 4)
+ yield indent(&"raise newException(Exception, \"missing value ({msg.names})\")", 8)
+ yield indent("tbl[key] = value", 4)
+ yield ""
+
iterator genReadMessageProc(msg: Message): string =
yield &"proc read{msg.names}*(stream: ProtobufStream): {msg.names} ="
yield indent(&"result = new{msg.names}()", 4)
@@ -542,7 +652,13 @@ iterator genReadMessageProc(msg: Message): string =
&"set{field.name}"
yield indent(&"of {field.number}:", 8)
if isRepeated(field):
- if isNumeric(field):
+ if isMapEntry(field):
+ yield indent("let", 12)
+ yield indent("size = readVarint(stream)", 16)
+ yield indent("data = readStr(stream, int(size))", 16)
+ yield indent("pbs = newProtobufStream(newStringStream(data))", 16)
+ yield indent(&"read{field.typeName}KV(pbs, result.{field.name})", 12)
+ elif isNumeric(field):
yield indent("if wireType == WireType.LengthDelimited:", 12)
yield indent("let", 16)
yield indent("size = readVarint(stream)", 20)
@@ -575,10 +691,32 @@ iterator genReadMessageProc(msg: Message): string =
yield indent("else: skipField(stream, wireType)", 8)
yield ""
+iterator genSizeOfMapKVProc(message: Message): string =
+ let
+ key = mapKeyField(message)
+ value = mapValueField(message)
+
+ yield &"proc sizeOf{message.names}KV(key: {key.fullType}, value: {value.fullType}): uint64 ="
+ yield indent(&"result = result + sizeOf{key.typeName}(key)", 4)
+ yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({key.number}, {key.wiretypeStr})))", 4)
+ yield indent(&"let valueSize = sizeOf{value.typeName}(value)", 4)
+ yield indent(&"result = result + valueSize", 4)
+ yield indent(&"result = result + sizeOfUInt32(uint32(makeTag({value.number}, {value.wiretypeStr})))", 4)
+ if isMessage(value):
+ yield indent(&"result = result + sizeOfUInt64(valueSize)", 4)
+ yield ""
+
iterator genSizeOfMessageProc(msg: Message): string =
yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64 ="
for field in msg.fields:
- if isRepeated(field):
+ if isMapEntry(field):
+ yield indent(&"if has{field.name}(message):", 4)
+ yield indent(&"var sizeOfKV = 0'u64", 8)
+ yield indent(&"for key, value in message.{field.name}:", 8)
+ yield indent(&"sizeOfKV = sizeOfKV + sizeOf{field.typeName}KV(key, value)", 12)
+ yield indent(&"let sizeOfTag = sizeOfUInt32(uint32(makeTag({field.number}, {wiretypeStr(field)})))", 8)
+ yield indent("result = result + sizeOfKV + sizeOfTag + sizeOfUInt64(sizeOfKV)", 8)
+ elif isRepeated(field):
if isNumeric(field):
yield indent(&"""
if has{field.name}(message):
@@ -617,11 +755,16 @@ iterator genProcs(msg: Message): string =
for line in genHasFieldProc(msg, field): yield line
for line in genSetFieldProc(msg, field): yield line
- if isRepeated(field):
+ if isRepeated(field) and not isMapEntry(field):
for line in genAddToFieldProc(msg, field): yield line
for line in genFieldAccessorProcs(msg, field): yield line
+ if isMapEntry(msg):
+ for line in genSizeOfMapKVProc(msg): yield line
+ for line in genWriteMapKVProc(msg): yield line
+ for line in genReadMapKVProc(msg): yield line
+
for line in genSizeOfMessageProc(msg): yield line
for line in genWriteMessageProc(msg): yield line
for line in genReadMessageProc(msg): yield line
@@ -652,9 +795,20 @@ proc processFile(filename: string, fdesc: FileDescriptorProto,
result.name = pbfilename
result.data = ""
+ let parsed = parseFile(filename, fdesc)
+
+ var hasMaps = false
+ for message in parsed.messages:
+ let tmp = fixMapEntry(parsed, message)
+ if tmp:
+ hasMaps = true
+
addLine(result.data, "# Generated by protoc_gen_nim. Do not edit!")
addLine(result.data, "")
addLine(result.data, "import intsets")
+ if hasMaps:
+ addLine(result.data, "import tables")
+ addLine(result.data, "export tables")
addLine(result.data, "")
addLine(result.data, "import protobuf/stream")
addLine(result.data, "import protobuf/types")
@@ -668,8 +822,6 @@ proc processFile(filename: string, fdesc: FileDescriptorProto,
if hasDependency(fdesc):
addLine(result.data, "")
- let parsed = parseFile(filename, fdesc)
-
addLine(result.data, "type")
for e in parsed.enums: