aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nimpb/compiler/generator.nim111
1 files changed, 82 insertions, 29 deletions
diff --git a/nimpb/compiler/generator.nim b/nimpb/compiler/generator.nim
index beea943..baf9d6a 100644
--- a/nimpb/compiler/generator.nim
+++ b/nimpb/compiler/generator.nim
@@ -352,6 +352,17 @@ proc sizeOfProc(field: Field): string =
else:
result = &"sizeOf{field.typeName}"
+proc shouldGenerateHasField(msg: Message, field: Field): bool =
+ if msg.file.syntax == Syntax.Proto2:
+ result = true
+ else:
+ if isMapEntry(field):
+ result = false
+ elif isMessage(field):
+ result = true
+ elif field.oneof != nil:
+ result = true
+
proc newField(file: ProtoFile, message: Message, desc: google_protobuf_FieldDescriptorProto): Field =
new(result)
@@ -662,11 +673,13 @@ iterator genType(message: Message): string =
yield indent(&"{message.names}_{oneof.name}_Kind* {{.pure.}} = enum", 0)
for field in oneof.fields:
yield indent(&"{field.name.capitalizeAscii}", 4)
+ yield indent(&"NotSet", 4)
yield ""
yield &"{message.names}_{oneof.name}_OneOf* = object"
yield indent(&"case kind*: {message.names}_{oneof.name}_Kind", 4)
for field in oneof.fields:
yield indent(&"of {message.names}_{oneof.name}_Kind.{field.name.capitalizeAscii}: {quoteReserved(field.name)}*: {field.fullType}", 4)
+ yield indent(&"of {message.names}_{oneof.name}_Kind.NotSet: nil", 4)
iterator genNewMessageProc(msg: Message): string =
yield &"proc new{msg.names}*(): {msg.names} ="
@@ -675,6 +688,8 @@ iterator genNewMessageProc(msg: Message): string =
for field in msg.fields:
if field.oneof == nil:
yield indent(&"result.{field.accessor} = {defaultValue(field)}", 4)
+ for oneof in msg.oneofs:
+ yield indent(&"result.{oneof.name}.kind = {msg.names}_{oneof.name}_Kind.NotSet", 4)
yield ""
iterator oneofSiblings(field: Field): Field =
@@ -689,11 +704,14 @@ iterator genClearFieldProc(msg: Message, field: Field): string =
if field.oneof == nil:
yield indent(&"message.{field.accessor} = {defaultValue(field)}", 4)
else:
- yield indent(&"reset(message.{field.oneof.name})", 4)
- var numbers: seq[int] = @[field.number]
- for sibling in oneofSiblings(field):
- add(numbers, sibling.number)
- yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4)
+ let oneof = field.oneof
+ yield indent(&"reset(message.{oneof.name})", 4)
+ yield indent(&"message.{oneof.name}.kind = {msg.names}_{oneof.name}_Kind.NotSet", 4)
+ if shouldGenerateHasField(msg, field):
+ var numbers: seq[int] = @[field.number]
+ for sibling in oneofSiblings(field):
+ add(numbers, sibling.number)
+ yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4)
yield ""
iterator genHasFieldProc(msg: Message, field: Field): string =
@@ -713,18 +731,18 @@ iterator genSetFieldProc(msg: Message, field: Field): string =
yield indent(&"reset(message.{field.oneof.name})", 8)
yield indent(&"message.{field.oneof.name}.kind = {msg.names}_{field.oneof.name}_Kind.{field.name.capitalizeAscii}", 8)
yield indent(&"message.{field.accessor} = value", 4)
- yield indent(&"setField(message, {field.number})", 4)
- var numbers: seq[int] = @[]
- for sibling in oneofSiblings(field):
- add(numbers, sibling.number)
- if len(numbers) > 0:
- yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4)
+ if shouldGenerateHasField(msg, field):
+ yield indent(&"setField(message, {field.number})", 4)
+ var numbers: seq[int] = @[]
+ for sibling in oneofSiblings(field):
+ add(numbers, sibling.number)
+ if len(numbers) > 0:
+ yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4)
yield ""
iterator genAddToFieldProc(msg: Message, field: Field): string =
yield &"proc add{field.name}*(message: {msg.names}, value: {field.nimTypeName}) ="
yield indent(&"add(message.{field.name}, value)", 4)
- yield indent(&"setField(message, {field.number})", 4)
yield ""
iterator genFieldAccessorProcs(msg: Message, field: Field): string =
@@ -746,6 +764,43 @@ iterator genWriteMapKVProc(msg: Message): string =
yield indent(&"{value.writeProc}(stream, value, {value.number})", 4)
yield ""
+proc hasFieldCheck(msg: string, field: Field): string =
+ if isRepeated(field) or isMapEntry(field):
+ return &"len({msg}.{field.accessor}) > 0"
+ elif field.oneof != nil:
+ # Oneof fields only check the kind of the oneof field. If we anded this
+ # check with a check from below, we couldn't convey to the deserializing
+ # side which oneof field was set. For example, if we are serializing an
+ # string, if we included the string check, we wouldn't serialize
+ # anything. This would make the receiving end think that the oneof field
+ # didn't have anything set even though on the sending side the string
+ # field was actually set. So by also serializing the default values on
+ # the wire, we give some information about the oneof field itself.
+ return &"{msg}.{field.oneof.name}.kind == {field.message.names}_{field.oneof.name}_Kind.{field.name.capitalizeAscii}"
+
+ case field.ftype
+ of google_protobuf_FieldDescriptorProtoType.TypeDouble,
+ google_protobuf_FieldDescriptorProtoType.TypeFloat,
+ google_protobuf_FieldDescriptorProtoType.TypeInt64,
+ google_protobuf_FieldDescriptorProtoType.TypeUInt64,
+ google_protobuf_FieldDescriptorProtoType.TypeInt32,
+ google_protobuf_FieldDescriptorProtoType.TypeFixed64,
+ google_protobuf_FieldDescriptorProtoType.TypeFixed32,
+ google_protobuf_FieldDescriptorProtoType.TypeUInt32,
+ google_protobuf_FieldDescriptorProtoType.TypeSFixed32,
+ google_protobuf_FieldDescriptorProtoType.TypeSFixed64,
+ google_protobuf_FieldDescriptorProtoType.TypeSInt32,
+ google_protobuf_FieldDescriptorProtoType.TypeSInt64,
+ google_protobuf_FieldDescriptorProtoType.TypeBool,
+ google_protobuf_FieldDescriptorProtoType.TypeEnum:
+ result = &"{msg}.{field.accessor} != {defaultValue(field)}"
+ of google_protobuf_FieldDescriptorProtoType.TypeString,
+ google_protobuf_FieldDescriptorProtoType.TypeBytes:
+ result = &"len({msg}.{field.accessor}) > 0"
+ of google_protobuf_FieldDescriptorProtoType.TypeGroup: result = ""
+ of google_protobuf_FieldDescriptorProtoType.TypeMessage:
+ result = &"has{field.name}({msg})"
+
iterator genWriteMessageProc(msg: Message): string =
yield &"proc write{msg.names}*(stream: Stream, message: {msg.names}) ="
@@ -757,7 +812,8 @@ iterator genWriteMessageProc(msg: Message): string =
yield indent(&"{field.writeProc}(stream, key, value)", 8)
elif isRepeated(field):
if field.packed:
- yield indent(&"if has{field.name}(message):", 4)
+ let check = hasFieldCheck("message", field)
+ yield indent(&"if {check}:", 4)
yield indent(&"writeTag(stream, {field.number}, WireType.LengthDelimited)", 8)
yield indent(&"writeVarint(stream, packedFieldSize(message.{field.name}, {field.fieldTypeStr}))", 8)
yield indent(&"for value in message.{field.name}:", 8)
@@ -766,7 +822,8 @@ iterator genWriteMessageProc(msg: Message): string =
yield indent(&"for value in message.{field.name}:", 4)
yield indent(&"{field.writeProc}(stream, value, {field.number})", 8)
else:
- yield indent(&"if has{field.name}(message):", 4)
+ let check = hasFieldCheck("message", field)
+ yield indent(&"if {check}:", 4)
yield indent(&"{field.writeProc}(stream, message.{field.accessor}, {field.number})", 8)
yield indent("writeUnknownFields(stream, message)", 4)
@@ -891,8 +948,9 @@ iterator genSizeOfMapKVProc(message: Message): string =
iterator genSizeOfMessageProc(msg: Message): string =
yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64 ="
for field in msg.fields:
+ let check = hasFieldCheck("message", field)
if isMapEntry(field):
- yield indent(&"if has{field.name}(message):", 4)
+ yield indent(&"if {check}:", 4)
yield indent(&"var sizeOfKV = 0'u64", 8)
yield indent(&"for key, value in message.{field.name}:", 8)
yield indent(&"sizeOfKV = sizeOfKV + {field.sizeOfProc}(key, value)", 12)
@@ -900,7 +958,7 @@ iterator genSizeOfMessageProc(msg: Message): string =
yield indent(&"result = result + sizeOfLengthDelimited(sizeOfKV)", 8)
elif isRepeated(field):
if isNumeric(field):
- yield indent(&"if has{field.name}(message):", 4)
+ yield indent(&"if {check}:", 4)
yield indent(&"result = result + sizeOfTag({field.number}, WireType.LengthDelimited)", 8)
yield indent(&"result = result + sizeOfLengthDelimited(packedFieldSize(message.{field.name}, {field.fieldTypeStr}))", 8)
else:
@@ -911,7 +969,7 @@ iterator genSizeOfMessageProc(msg: Message): string =
else:
yield indent(&"result = result + {field.sizeOfProc}(value)", 8)
else:
- yield indent(&"if has{field.name}(message):", 4)
+ yield indent(&"if {check}:", 4)
yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8)
if isMessage(field):
yield indent(&"result = result + sizeOfLengthDelimited({field.sizeOfProc}(message.{field.accessor}))", 8)
@@ -964,7 +1022,8 @@ iterator genMessageToJsonProc(msg: Message): string =
result = &"%{v}"
for field in msg.fields:
- yield indent(&"if has{field.name}(message):", 4)
+ let check = hasFieldCheck("message", field)
+ yield indent(&"if {check}:", 4)
if isMapEntry(field):
yield indent("let obj = newJObject()", 8)
yield indent(&"for key, value in message.{field.name}:", 8)
@@ -1008,19 +1067,12 @@ iterator genMessageFromJsonProc(msg: Message): string =
elif field.ftype == google_protobuf_FieldDescriptorProto_Type.TypeBytes:
result = &"parseBytes({n})"
- var oneOfsHandled: seq[string] = @[]
-
for field in msg.fields:
- if field.oneof != nil:
- if field.oneof.name notin oneOfsHandled:
- yield indent(&"var {field.oneof.name}Done = false", 4)
- add(oneOfsHandled, field.oneof.name)
yield indent(&"node = getJsonField(obj, \"{field.protoName}\", \"{field.jsonName}\")", 4)
yield indent(&"if node != nil and node.kind != JNull:", 4)
if field.oneof != nil:
- yield indent(&"if {field.oneof.name}Done:", 8)
+ yield indent(&"if result.{field.oneof.name}.kind != {field.message.names}_{field.oneof.name}_Kind.NotSet:", 8)
yield indent(&"raise newException(nimpb_json.ParseError, \"multiple values for oneof encountered\")", 12)
- yield indent(&"{field.oneof.name}Done = true", 8)
if isMapEntry(field):
yield indent("if node.kind != JObject:", 8)
yield indent("raise newException(ValueError, \"not an object\")", 12)
@@ -1059,7 +1111,7 @@ iterator genMessageProcForwards(msg: Message): string =
yield &"proc write{msg.names}*(stream: Stream, message: {msg.names})"
yield &"proc read{msg.names}*(stream: Stream): {msg.names}"
yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64"
- if shouldGenerateJsonProcs($msg.names):
+ if msg.file.syntax == Syntax.Proto3 and shouldGenerateJsonProcs($msg.names):
yield &"proc toJson*(message: {msg.names}): JsonNode"
yield &"proc parse{msg.names}*(obj: JsonNode): {msg.names}"
else:
@@ -1081,7 +1133,8 @@ iterator genProcs(msg: Message): string =
for field in msg.fields:
for line in genClearFieldProc(msg, field): yield line
- for line in genHasFieldProc(msg, field): yield line
+ if shouldGenerateHasField(msg, field):
+ for line in genHasFieldProc(msg, field): yield line
for line in genSetFieldProc(msg, field): yield line
if isRepeated(field) and not isMapEntry(field):
@@ -1093,7 +1146,7 @@ iterator genProcs(msg: Message): string =
for line in genWriteMessageProc(msg): yield line
for line in genReadMessageProc(msg): yield line
- if shouldGenerateJsonProcs($msg.names):
+ if msg.file.syntax == Syntax.Proto3 and shouldGenerateJsonProcs($msg.names):
for line in genMessageToJsonProc(msg): yield line
for line in genMessageFromJsonProc(msg): yield line