diff options
| -rw-r--r-- | nimpb/compiler/generator.nim | 111 |
1 files changed, 82 insertions, 29 deletions
diff --git a/nimpb/compiler/generator.nim b/nimpb/compiler/generator.nim index beea943..baf9d6a 100644 --- a/nimpb/compiler/generator.nim +++ b/nimpb/compiler/generator.nim @@ -352,6 +352,17 @@ proc sizeOfProc(field: Field): string = else: result = &"sizeOf{field.typeName}" +proc shouldGenerateHasField(msg: Message, field: Field): bool = + if msg.file.syntax == Syntax.Proto2: + result = true + else: + if isMapEntry(field): + result = false + elif isMessage(field): + result = true + elif field.oneof != nil: + result = true + proc newField(file: ProtoFile, message: Message, desc: google_protobuf_FieldDescriptorProto): Field = new(result) @@ -662,11 +673,13 @@ iterator genType(message: Message): string = yield indent(&"{message.names}_{oneof.name}_Kind* {{.pure.}} = enum", 0) for field in oneof.fields: yield indent(&"{field.name.capitalizeAscii}", 4) + yield indent(&"NotSet", 4) yield "" yield &"{message.names}_{oneof.name}_OneOf* = object" yield indent(&"case kind*: {message.names}_{oneof.name}_Kind", 4) for field in oneof.fields: yield indent(&"of {message.names}_{oneof.name}_Kind.{field.name.capitalizeAscii}: {quoteReserved(field.name)}*: {field.fullType}", 4) + yield indent(&"of {message.names}_{oneof.name}_Kind.NotSet: nil", 4) iterator genNewMessageProc(msg: Message): string = yield &"proc new{msg.names}*(): {msg.names} =" @@ -675,6 +688,8 @@ iterator genNewMessageProc(msg: Message): string = for field in msg.fields: if field.oneof == nil: yield indent(&"result.{field.accessor} = {defaultValue(field)}", 4) + for oneof in msg.oneofs: + yield indent(&"result.{oneof.name}.kind = {msg.names}_{oneof.name}_Kind.NotSet", 4) yield "" iterator oneofSiblings(field: Field): Field = @@ -689,11 +704,14 @@ iterator genClearFieldProc(msg: Message, field: Field): string = if field.oneof == nil: yield indent(&"message.{field.accessor} = {defaultValue(field)}", 4) else: - yield indent(&"reset(message.{field.oneof.name})", 4) - var numbers: seq[int] = @[field.number] - for sibling in oneofSiblings(field): - add(numbers, sibling.number) - yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4) + let oneof = field.oneof + yield indent(&"reset(message.{oneof.name})", 4) + yield indent(&"message.{oneof.name}.kind = {msg.names}_{oneof.name}_Kind.NotSet", 4) + if shouldGenerateHasField(msg, field): + var numbers: seq[int] = @[field.number] + for sibling in oneofSiblings(field): + add(numbers, sibling.number) + yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4) yield "" iterator genHasFieldProc(msg: Message, field: Field): string = @@ -713,18 +731,18 @@ iterator genSetFieldProc(msg: Message, field: Field): string = yield indent(&"reset(message.{field.oneof.name})", 8) yield indent(&"message.{field.oneof.name}.kind = {msg.names}_{field.oneof.name}_Kind.{field.name.capitalizeAscii}", 8) yield indent(&"message.{field.accessor} = value", 4) - yield indent(&"setField(message, {field.number})", 4) - var numbers: seq[int] = @[] - for sibling in oneofSiblings(field): - add(numbers, sibling.number) - if len(numbers) > 0: - yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4) + if shouldGenerateHasField(msg, field): + yield indent(&"setField(message, {field.number})", 4) + var numbers: seq[int] = @[] + for sibling in oneofSiblings(field): + add(numbers, sibling.number) + if len(numbers) > 0: + yield indent(&"clearFields(message, [{join(numbers, \", \")}])", 4) yield "" iterator genAddToFieldProc(msg: Message, field: Field): string = yield &"proc add{field.name}*(message: {msg.names}, value: {field.nimTypeName}) =" yield indent(&"add(message.{field.name}, value)", 4) - yield indent(&"setField(message, {field.number})", 4) yield "" iterator genFieldAccessorProcs(msg: Message, field: Field): string = @@ -746,6 +764,43 @@ iterator genWriteMapKVProc(msg: Message): string = yield indent(&"{value.writeProc}(stream, value, {value.number})", 4) yield "" +proc hasFieldCheck(msg: string, field: Field): string = + if isRepeated(field) or isMapEntry(field): + return &"len({msg}.{field.accessor}) > 0" + elif field.oneof != nil: + # Oneof fields only check the kind of the oneof field. If we anded this + # check with a check from below, we couldn't convey to the deserializing + # side which oneof field was set. For example, if we are serializing an + # string, if we included the string check, we wouldn't serialize + # anything. This would make the receiving end think that the oneof field + # didn't have anything set even though on the sending side the string + # field was actually set. So by also serializing the default values on + # the wire, we give some information about the oneof field itself. + return &"{msg}.{field.oneof.name}.kind == {field.message.names}_{field.oneof.name}_Kind.{field.name.capitalizeAscii}" + + case field.ftype + of google_protobuf_FieldDescriptorProtoType.TypeDouble, + google_protobuf_FieldDescriptorProtoType.TypeFloat, + google_protobuf_FieldDescriptorProtoType.TypeInt64, + google_protobuf_FieldDescriptorProtoType.TypeUInt64, + google_protobuf_FieldDescriptorProtoType.TypeInt32, + google_protobuf_FieldDescriptorProtoType.TypeFixed64, + google_protobuf_FieldDescriptorProtoType.TypeFixed32, + google_protobuf_FieldDescriptorProtoType.TypeUInt32, + google_protobuf_FieldDescriptorProtoType.TypeSFixed32, + google_protobuf_FieldDescriptorProtoType.TypeSFixed64, + google_protobuf_FieldDescriptorProtoType.TypeSInt32, + google_protobuf_FieldDescriptorProtoType.TypeSInt64, + google_protobuf_FieldDescriptorProtoType.TypeBool, + google_protobuf_FieldDescriptorProtoType.TypeEnum: + result = &"{msg}.{field.accessor} != {defaultValue(field)}" + of google_protobuf_FieldDescriptorProtoType.TypeString, + google_protobuf_FieldDescriptorProtoType.TypeBytes: + result = &"len({msg}.{field.accessor}) > 0" + of google_protobuf_FieldDescriptorProtoType.TypeGroup: result = "" + of google_protobuf_FieldDescriptorProtoType.TypeMessage: + result = &"has{field.name}({msg})" + iterator genWriteMessageProc(msg: Message): string = yield &"proc write{msg.names}*(stream: Stream, message: {msg.names}) =" @@ -757,7 +812,8 @@ iterator genWriteMessageProc(msg: Message): string = yield indent(&"{field.writeProc}(stream, key, value)", 8) elif isRepeated(field): if field.packed: - yield indent(&"if has{field.name}(message):", 4) + let check = hasFieldCheck("message", field) + yield indent(&"if {check}:", 4) yield indent(&"writeTag(stream, {field.number}, WireType.LengthDelimited)", 8) yield indent(&"writeVarint(stream, packedFieldSize(message.{field.name}, {field.fieldTypeStr}))", 8) yield indent(&"for value in message.{field.name}:", 8) @@ -766,7 +822,8 @@ iterator genWriteMessageProc(msg: Message): string = yield indent(&"for value in message.{field.name}:", 4) yield indent(&"{field.writeProc}(stream, value, {field.number})", 8) else: - yield indent(&"if has{field.name}(message):", 4) + let check = hasFieldCheck("message", field) + yield indent(&"if {check}:", 4) yield indent(&"{field.writeProc}(stream, message.{field.accessor}, {field.number})", 8) yield indent("writeUnknownFields(stream, message)", 4) @@ -891,8 +948,9 @@ iterator genSizeOfMapKVProc(message: Message): string = iterator genSizeOfMessageProc(msg: Message): string = yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64 =" for field in msg.fields: + let check = hasFieldCheck("message", field) if isMapEntry(field): - yield indent(&"if has{field.name}(message):", 4) + yield indent(&"if {check}:", 4) yield indent(&"var sizeOfKV = 0'u64", 8) yield indent(&"for key, value in message.{field.name}:", 8) yield indent(&"sizeOfKV = sizeOfKV + {field.sizeOfProc}(key, value)", 12) @@ -900,7 +958,7 @@ iterator genSizeOfMessageProc(msg: Message): string = yield indent(&"result = result + sizeOfLengthDelimited(sizeOfKV)", 8) elif isRepeated(field): if isNumeric(field): - yield indent(&"if has{field.name}(message):", 4) + yield indent(&"if {check}:", 4) yield indent(&"result = result + sizeOfTag({field.number}, WireType.LengthDelimited)", 8) yield indent(&"result = result + sizeOfLengthDelimited(packedFieldSize(message.{field.name}, {field.fieldTypeStr}))", 8) else: @@ -911,7 +969,7 @@ iterator genSizeOfMessageProc(msg: Message): string = else: yield indent(&"result = result + {field.sizeOfProc}(value)", 8) else: - yield indent(&"if has{field.name}(message):", 4) + yield indent(&"if {check}:", 4) yield indent(&"result = result + sizeOfTag({field.number}, {field.wiretypeStr})", 8) if isMessage(field): yield indent(&"result = result + sizeOfLengthDelimited({field.sizeOfProc}(message.{field.accessor}))", 8) @@ -964,7 +1022,8 @@ iterator genMessageToJsonProc(msg: Message): string = result = &"%{v}" for field in msg.fields: - yield indent(&"if has{field.name}(message):", 4) + let check = hasFieldCheck("message", field) + yield indent(&"if {check}:", 4) if isMapEntry(field): yield indent("let obj = newJObject()", 8) yield indent(&"for key, value in message.{field.name}:", 8) @@ -1008,19 +1067,12 @@ iterator genMessageFromJsonProc(msg: Message): string = elif field.ftype == google_protobuf_FieldDescriptorProto_Type.TypeBytes: result = &"parseBytes({n})" - var oneOfsHandled: seq[string] = @[] - for field in msg.fields: - if field.oneof != nil: - if field.oneof.name notin oneOfsHandled: - yield indent(&"var {field.oneof.name}Done = false", 4) - add(oneOfsHandled, field.oneof.name) yield indent(&"node = getJsonField(obj, \"{field.protoName}\", \"{field.jsonName}\")", 4) yield indent(&"if node != nil and node.kind != JNull:", 4) if field.oneof != nil: - yield indent(&"if {field.oneof.name}Done:", 8) + yield indent(&"if result.{field.oneof.name}.kind != {field.message.names}_{field.oneof.name}_Kind.NotSet:", 8) yield indent(&"raise newException(nimpb_json.ParseError, \"multiple values for oneof encountered\")", 12) - yield indent(&"{field.oneof.name}Done = true", 8) if isMapEntry(field): yield indent("if node.kind != JObject:", 8) yield indent("raise newException(ValueError, \"not an object\")", 12) @@ -1059,7 +1111,7 @@ iterator genMessageProcForwards(msg: Message): string = yield &"proc write{msg.names}*(stream: Stream, message: {msg.names})" yield &"proc read{msg.names}*(stream: Stream): {msg.names}" yield &"proc sizeOf{msg.names}*(message: {msg.names}): uint64" - if shouldGenerateJsonProcs($msg.names): + if msg.file.syntax == Syntax.Proto3 and shouldGenerateJsonProcs($msg.names): yield &"proc toJson*(message: {msg.names}): JsonNode" yield &"proc parse{msg.names}*(obj: JsonNode): {msg.names}" else: @@ -1081,7 +1133,8 @@ iterator genProcs(msg: Message): string = for field in msg.fields: for line in genClearFieldProc(msg, field): yield line - for line in genHasFieldProc(msg, field): yield line + if shouldGenerateHasField(msg, field): + for line in genHasFieldProc(msg, field): yield line for line in genSetFieldProc(msg, field): yield line if isRepeated(field) and not isMapEntry(field): @@ -1093,7 +1146,7 @@ iterator genProcs(msg: Message): string = for line in genWriteMessageProc(msg): yield line for line in genReadMessageProc(msg): yield line - if shouldGenerateJsonProcs($msg.names): + if msg.file.syntax == Syntax.Proto3 and shouldGenerateJsonProcs($msg.names): for line in genMessageToJsonProc(msg): yield line for line in genMessageFromJsonProc(msg): yield line |
