hack: generate vtproto files for buildx

Integrates vtproto into buildx. The generated files dockerfile has been
modified to copy the buildkit equivalent file to ensure files are laid
out in the appropriate way for imports.

An import has also been included to change the grpc codec to the version
in buildkit that supports vtproto. This will allow buildx to utilize the
speed and memory improvements from that.

Also updates the gc control options for prune.

Signed-off-by: Jonathan A. Sternberg <jonathan.sternberg@docker.com>
This commit is contained in:
Jonathan A. Sternberg
2024-10-08 13:35:06 -05:00
parent d353f5f6ba
commit 64c5139ab6
109 changed files with 68070 additions and 2941 deletions

View File

@ -0,0 +1,338 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package clone
import (
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
const (
cloneName = "CloneVT"
cloneMessageName = "CloneMessageVT"
)
var (
protoPkg = protogen.GoImportPath("google.golang.org/protobuf/proto")
)
func init() {
generator.RegisterFeature("clone", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &clone{GeneratedFile: gen}
})
}
type clone struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*clone)(nil)
func (p *clone) Name() string {
return "clone"
}
func (p *clone) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.processMessage(proto3, message)
}
return p.once
}
// cloneOneofField generates the statements for cloning a oneof field
func (p *clone) cloneOneofField(lhsBase, rhsBase string, oneof *protogen.Oneof) {
fieldname := oneof.GoName
ccInterfaceName := "is" + oneof.GoIdent.GoName
lhs := lhsBase + "." + fieldname
rhs := rhsBase + "." + fieldname
p.P(`if `, rhs, ` != nil {`)
if p.IsWellKnownType(oneof.Parent) {
p.P(`switch c := `, rhs, `.(type) {`)
for _, f := range oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
p.P(lhs, `= (*`, f.GoIdent, `)((*`, p.WellKnownFieldMap(f), `)(c).`, cloneName, `())`)
}
p.P(`}`)
} else {
p.P(lhs, ` = `, rhs, `.(interface{ `, cloneName, `() `, ccInterfaceName, ` }).`, cloneName, `()`)
}
p.P(`}`)
}
// cloneFieldSingular generates the code for cloning a singular, non-oneof field.
func (p *clone) cloneFieldSingular(lhs, rhs string, kind protoreflect.Kind, message *protogen.Message) {
switch {
case kind == protoreflect.MessageKind, kind == protoreflect.GroupKind:
switch {
case p.IsWellKnownType(message):
p.P(lhs, ` = (*`, message.GoIdent, `)((*`, p.WellKnownTypeMap(message), `)(`, rhs, `).`, cloneName, `())`)
case p.IsLocalMessage(message):
p.P(lhs, ` = `, rhs, `.`, cloneName, `()`)
default:
// rhs is a concrete type, we need to first convert it to an interface in order to use an interface
// type assertion.
p.P(`if vtpb, ok := interface{}(`, rhs, `).(interface{ `, cloneName, `() *`, message.GoIdent, ` }); ok {`)
p.P(lhs, ` = vtpb.`, cloneName, `()`)
p.P(`} else {`)
p.P(lhs, ` = `, protoPkg.Ident("Clone"), `(`, rhs, `).(*`, message.GoIdent, `)`)
p.P(`}`)
}
case kind == protoreflect.BytesKind:
p.P(`tmpBytes := make([]byte, len(`, rhs, `))`)
p.P(`copy(tmpBytes, `, rhs, `)`)
p.P(lhs, ` = tmpBytes`)
case isScalar(kind):
p.P(lhs, ` = `, rhs)
default:
panic("unexpected")
}
}
// cloneField generates the code for cloning a field in a protobuf.
func (p *clone) cloneField(lhsBase, rhsBase string, allFieldsNullable bool, field *protogen.Field) {
// At this point, if we encounter a non-synthetic oneof, we assume it to be the representative
// field for that oneof.
if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
p.cloneOneofField(lhsBase, rhsBase, field.Oneof)
return
}
if !isReference(allFieldsNullable, field) {
panic("method should not be invoked for non-reference fields")
}
fieldname := field.GoName
lhs := lhsBase + "." + fieldname
rhs := rhsBase + "." + fieldname
// At this point, we are only looking at reference types (pointers, maps, slices, interfaces), which can all
// be nil.
p.P(`if rhs := `, rhs, `; rhs != nil {`)
rhs = "rhs"
fieldKind := field.Desc.Kind()
msg := field.Message // possibly nil
if field.Desc.Cardinality() == protoreflect.Repeated { // maps and slices
goType, _ := p.FieldGoType(field)
p.P(`tmpContainer := make(`, goType, `, len(`, rhs, `))`)
if isScalar(fieldKind) && field.Desc.IsList() {
// Generated code optimization: instead of iterating over all (key/index, value) pairs,
// do a single copy(dst, src) invocation for slices whose elements aren't reference types.
p.P(`copy(tmpContainer, `, rhs, `)`)
} else {
if field.Desc.IsMap() {
// For maps, the type of the value field determines what code is generated for cloning
// an entry.
valueField := field.Message.Fields[1]
fieldKind = valueField.Desc.Kind()
msg = valueField.Message
}
p.P(`for k, v := range `, rhs, ` {`)
p.cloneFieldSingular("tmpContainer[k]", "v", fieldKind, msg)
p.P(`}`)
}
p.P(lhs, ` = tmpContainer`)
} else if isScalar(fieldKind) {
p.P(`tmpVal := *`, rhs)
p.P(lhs, ` = &tmpVal`)
} else {
p.cloneFieldSingular(lhs, rhs, fieldKind, msg)
}
p.P(`}`)
}
func (p *clone) generateCloneMethodsForMessage(proto3 bool, message *protogen.Message) {
ccTypeName := message.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, cloneName, `() *`, ccTypeName, ` {`)
p.body(!proto3, ccTypeName, message, true)
p.P(`}`)
p.P()
if !p.Wrapper() {
p.P(`func (m *`, ccTypeName, `) `, cloneMessageName, `() `, protoPkg.Ident("Message"), ` {`)
p.P(`return m.`, cloneName, `()`)
p.P(`}`)
p.P()
}
}
// body generates the code for the actual cloning logic of a structure containing the given fields.
// In practice, those can be the fields of a message.
// The object to be cloned is assumed to be called "m".
func (p *clone) body(allFieldsNullable bool, ccTypeName string, message *protogen.Message, cloneUnknownFields bool) {
// The method body for a message or a oneof wrapper always starts with a nil check.
p.P(`if m == nil {`)
// We use an explicitly typed nil to avoid returning the nil interface in the oneof wrapper
// case.
p.P(`return (*`, ccTypeName, `)(nil)`)
p.P(`}`)
fields := message.Fields
// Make a first pass over the fields, in which we initialize all non-reference fields via direct
// struct literal initialization, and extract all other (reference) fields for a second pass.
// Do not require qualified name because CloneVT generates in same file with definition.
p.Alloc("r", message, false)
var refFields []*protogen.Field
oneofFields := make(map[string]struct{}, len(fields))
for _, field := range fields {
if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
// Use the first field in a oneof as the representative for that oneof, disregard
// the other fields in that oneof.
if _, ok := oneofFields[field.Oneof.GoName]; !ok {
refFields = append(refFields, field)
oneofFields[field.Oneof.GoName] = struct{}{}
}
continue
}
if !isReference(allFieldsNullable, field) {
p.P(`r.`, field.GoName, ` = m.`, field.GoName)
continue
}
// Shortcut: for types where we know that an optimized clone method exists, we can call it directly as it is
// nil-safe.
if field.Desc.Cardinality() != protoreflect.Repeated {
switch {
case p.IsWellKnownType(field.Message):
p.P(`r.`, field.GoName, ` = (*`, field.Message.GoIdent, `)((*`, p.WellKnownTypeMap(field.Message), `)(m.`, field.GoName, `).`, cloneName, `())`)
continue
case p.IsLocalMessage(field.Message):
p.P(`r.`, field.GoName, ` = m.`, field.GoName, `.`, cloneName, `()`)
continue
}
}
refFields = append(refFields, field)
}
// Generate explicit assignment statements for all reference fields.
for _, field := range refFields {
p.cloneField("r", "m", allFieldsNullable, field)
}
if cloneUnknownFields && !p.Wrapper() {
// Clone unknown fields, if any
p.P(`if len(m.unknownFields) > 0 {`)
p.P(`r.unknownFields = make([]byte, len(m.unknownFields))`)
p.P(`copy(r.unknownFields, m.unknownFields)`)
p.P(`}`)
}
p.P(`return r`)
}
func (p *clone) bodyForOneOf(ccTypeName string, field *protogen.Field) {
// The method body for a message or a oneof wrapper always starts with a nil check.
p.P(`if m == nil {`)
// We use an explicitly typed nil to avoid returning the nil interface in the oneof wrapper
// case.
p.P(`return (*`, ccTypeName, `)(nil)`)
p.P(`}`)
p.P("r", " := new(", ccTypeName, `)`)
if !isReference(false, field) {
p.P(`r.`, field.GoName, ` = m.`, field.GoName)
p.P(`return r`)
return
}
// Shortcut: for types where we know that an optimized clone method exists, we can call it directly as it is
// nil-safe.
if field.Desc.Cardinality() != protoreflect.Repeated && field.Message != nil {
switch {
case p.IsWellKnownType(field.Message):
p.P(`r.`, field.GoName, ` = (*`, field.Message.GoIdent, `)((*`, p.WellKnownTypeMap(field.Message), `)(m.`, field.GoName, `).`, cloneName, `())`)
p.P(`return r`)
return
case p.IsLocalMessage(field.Message):
p.P(`r.`, field.GoName, ` = m.`, field.GoName, `.`, cloneName, `()`)
p.P(`return r`)
return
}
}
// Generate explicit assignment statements for reference field.
p.cloneField("r", "m", false, field)
p.P(`return r`)
}
// generateCloneMethodsForOneof generates the clone method for the oneof wrapper type of a
// field in a oneof.
func (p *clone) generateCloneMethodsForOneof(message *protogen.Message, field *protogen.Field) {
ccTypeName := field.GoIdent.GoName
ccInterfaceName := "is" + field.Oneof.GoIdent.GoName
if p.IsWellKnownType(message) {
p.P(`func (m *`, ccTypeName, `) `, cloneName, `() *`, ccTypeName, ` {`)
} else {
p.P(`func (m *`, ccTypeName, `) `, cloneName, `() `, ccInterfaceName, ` {`)
}
// Create a "fake" field for the single oneof member, pretending it is not a oneof field.
fieldInOneof := *field
fieldInOneof.Oneof = nil
// If we have a scalar field in a oneof, that field is never nullable, even when using proto2
p.bodyForOneOf(ccTypeName, &fieldInOneof)
p.P(`}`)
p.P()
}
func (p *clone) processMessageOneofs(message *protogen.Message) {
for _, field := range message.Fields {
if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() {
continue
}
p.generateCloneMethodsForOneof(message, field)
}
}
func (p *clone) processMessage(proto3 bool, message *protogen.Message) {
for _, nested := range message.Messages {
p.processMessage(proto3, nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
p.generateCloneMethodsForMessage(proto3, message)
p.processMessageOneofs(message)
}
// isReference checks whether the Go equivalent of the given field is of reference type, i.e., can be nil.
func isReference(allFieldsNullable bool, field *protogen.Field) bool {
if allFieldsNullable || field.Oneof != nil || field.Message != nil || field.Desc.Cardinality() == protoreflect.Repeated || field.Desc.Kind() == protoreflect.BytesKind {
return true
}
if !isScalar(field.Desc.Kind()) {
panic("unexpected non-reference, non-scalar field")
}
return false
}
func isScalar(kind protoreflect.Kind) bool {
switch kind {
case
protoreflect.BoolKind,
protoreflect.StringKind,
protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind,
protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind,
protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind,
protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind,
protoreflect.EnumKind:
return true
}
return false
}

View File

@ -0,0 +1,308 @@
// Copyright (c) 2022 PlanetScale Inc. All rights reserved.
package equal
import (
"fmt"
"sort"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
func init() {
generator.RegisterFeature("equal", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &equal{GeneratedFile: gen}
})
}
var (
protoPkg = protogen.GoImportPath("google.golang.org/protobuf/proto")
)
type equal struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*equal)(nil)
func (p *equal) Name() string { return "equal" }
func (p *equal) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.message(proto3, message)
}
return p.once
}
const equalName = "EqualVT"
const equalMessageName = "EqualMessageVT"
func (p *equal) message(proto3 bool, message *protogen.Message) {
for _, nested := range message.Messages {
p.message(proto3, nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
ccTypeName := message.GoIdent.GoName
p.P(`func (this *`, ccTypeName, `) `, equalName, `(that *`, ccTypeName, `) bool {`)
p.P(`if this == that {`)
p.P(` return true`)
p.P(`} else if this == nil || that == nil {`)
p.P(` return false`)
p.P(`}`)
sort.Slice(message.Fields, func(i, j int) bool {
return message.Fields[i].Desc.Number() < message.Fields[j].Desc.Number()
})
{
oneofs := make(map[string]struct{}, len(message.Fields))
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
continue
}
fieldname := field.Oneof.GoName
if _, ok := oneofs[fieldname]; ok {
continue
}
oneofs[fieldname] = struct{}{}
p.P(`if this.`, fieldname, ` == nil && that.`, fieldname, ` != nil {`)
p.P(` return false`)
p.P(`} else if this.`, fieldname, ` != nil {`)
p.P(` if that.`, fieldname, ` == nil {`)
p.P(` return false`)
p.P(` }`)
ccInterfaceName := fmt.Sprintf("is%s", field.Oneof.GoIdent.GoName)
if p.IsWellKnownType(message) {
p.P(`switch c := this.`, fieldname, `.(type) {`)
for _, f := range field.Oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
p.P(`if !(*`, p.WellKnownFieldMap(f), `)(c).`, equalName, `(that.`, fieldname, `) {`)
p.P(`return false`)
p.P(`}`)
}
p.P(`}`)
} else {
p.P(`if !this.`, fieldname, `.(interface{ `, equalName, `(`, ccInterfaceName, `) bool }).`, equalName, `(that.`, fieldname, `) {`)
p.P(`return false`)
p.P(`}`)
}
p.P(`}`)
}
}
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
nullable := field.Message != nil || (field.Oneof != nil && field.Oneof.Desc.IsSynthetic()) || (!proto3 && !oneof)
if !oneof {
p.field(field, nullable)
}
}
if p.Wrapper() {
p.P(`return true`)
} else {
p.P(`return string(this.unknownFields) == string(that.unknownFields)`)
}
p.P(`}`)
p.P()
if !p.Wrapper() {
p.P(`func (this *`, ccTypeName, `) `, equalMessageName, `(thatMsg `, protoPkg.Ident("Message"), `) bool {`)
p.P(`that, ok := thatMsg.(*`, ccTypeName, `)`)
p.P(`if !ok {`)
p.P(`return false`)
p.P(`}`)
p.P(`return this.`, equalName, `(that)`)
p.P(`}`)
}
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
continue
}
p.oneof(field)
}
}
func (p *equal) oneof(field *protogen.Field) {
ccTypeName := field.GoIdent.GoName
ccInterfaceName := fmt.Sprintf("is%s", field.Oneof.GoIdent.GoName)
fieldname := field.GoName
if p.IsWellKnownType(field.Parent) {
p.P(`func (this *`, ccTypeName, `) `, equalName, `(thatIface any) bool {`)
} else {
p.P(`func (this *`, ccTypeName, `) `, equalName, `(thatIface `, ccInterfaceName, `) bool {`)
}
p.P(`that, ok := thatIface.(*`, ccTypeName, `)`)
p.P(`if !ok {`)
if p.IsWellKnownType(field.Parent) {
p.P(`if ot, ok := thatIface.(*`, field.GoIdent, `); ok {`)
p.P(`that = (*`, ccTypeName, `)(ot)`)
p.P("} else {")
p.P("return false")
p.P("}")
} else {
p.P(`return false`)
}
p.P(`}`)
p.P(`if this == that {`)
p.P(`return true`)
p.P(`}`)
p.P(`if this == nil && that != nil || this != nil && that == nil {`)
p.P(`return false`)
p.P(`}`)
lhs := fmt.Sprintf("this.%s", fieldname)
rhs := fmt.Sprintf("that.%s", fieldname)
kind := field.Desc.Kind()
switch {
case isScalar(kind):
p.compareScalar(lhs, rhs, false)
case kind == protoreflect.BytesKind:
p.compareBytes(lhs, rhs, false)
case kind == protoreflect.MessageKind || kind == protoreflect.GroupKind:
p.compareCall(lhs, rhs, field.Message, false)
default:
panic("not implemented")
}
p.P(`return true`)
p.P(`}`)
p.P()
}
func (p *equal) field(field *protogen.Field, nullable bool) {
fieldname := field.GoName
repeated := field.Desc.Cardinality() == protoreflect.Repeated
lhs := fmt.Sprintf("this.%s", fieldname)
rhs := fmt.Sprintf("that.%s", fieldname)
if repeated {
p.P(`if len(`, lhs, `) != len(`, rhs, `) {`)
p.P(` return false`)
p.P(`}`)
p.P(`for i, vx := range `, lhs, ` {`)
if field.Desc.IsMap() {
p.P(`vy, ok := `, rhs, `[i]`)
p.P(`if !ok {`)
p.P(`return false`)
p.P(`}`)
field = field.Message.Fields[1]
} else {
p.P(`vy := `, rhs, `[i]`)
}
lhs, rhs = "vx", "vy"
nullable = false
}
kind := field.Desc.Kind()
switch {
case isScalar(kind):
p.compareScalar(lhs, rhs, nullable)
case kind == protoreflect.BytesKind:
p.compareBytes(lhs, rhs, nullable)
case kind == protoreflect.MessageKind || kind == protoreflect.GroupKind:
p.compareCall(lhs, rhs, field.Message, nullable)
default:
panic("not implemented")
}
if repeated {
// close for loop
p.P(`}`)
}
}
func (p *equal) compareScalar(lhs, rhs string, nullable bool) {
if nullable {
p.P(`if p, q := `, lhs, `, `, rhs, `; (p == nil && q != nil) || (p != nil && (q == nil || *p != *q)) {`)
} else {
p.P(`if `, lhs, ` != `, rhs, ` {`)
}
p.P(` return false`)
p.P(`}`)
}
func (p *equal) compareBytes(lhs, rhs string, nullable bool) {
if nullable {
p.P(`if p, q := `, lhs, `, `, rhs, `; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) {`)
} else {
// Inlined call to bytes.Equal()
p.P(`if string(`, lhs, `) != string(`, rhs, `) {`)
}
p.P(` return false`)
p.P(`}`)
}
func (p *equal) compareCall(lhs, rhs string, msg *protogen.Message, nullable bool) {
if !nullable {
// The p != q check is mostly intended to catch the lhs = nil, rhs = nil case in which we would pointlessly
// allocate not just one but two empty values. However, it also provides us with an extra scope to establish
// our p and q variables.
p.P(`if p, q := `, lhs, `, `, rhs, `; p != q {`)
defer p.P(`}`)
p.P(`if p == nil {`)
p.P(`p = &`, p.QualifiedGoIdent(msg.GoIdent), `{}`)
p.P(`}`)
p.P(`if q == nil {`)
p.P(`q = &`, p.QualifiedGoIdent(msg.GoIdent), `{}`)
p.P(`}`)
lhs, rhs = "p", "q"
}
switch {
case p.IsWellKnownType(msg):
wkt := p.WellKnownTypeMap(msg)
p.P(`if !(*`, wkt, `)(`, lhs, `).`, equalName, `((*`, wkt, `)(`, rhs, `)) {`)
p.P(` return false`)
p.P(`}`)
case p.IsLocalMessage(msg):
p.P(`if !`, lhs, `.`, equalName, `(`, rhs, `) {`)
p.P(` return false`)
p.P(`}`)
default:
p.P(`if equal, ok := interface{}(`, lhs, `).(interface { `, equalName, `(*`, p.QualifiedGoIdent(msg.GoIdent), `) bool }); ok {`)
p.P(` if !equal.`, equalName, `(`, rhs, `) {`)
p.P(` return false`)
p.P(` }`)
p.P(`} else if !`, p.Ident("google.golang.org/protobuf/proto", "Equal"), `(`, lhs, `, `, rhs, `) {`)
p.P(` return false`)
p.P(`}`)
}
}
func isScalar(kind protoreflect.Kind) bool {
switch kind {
case
protoreflect.BoolKind,
protoreflect.StringKind,
protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind,
protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind,
protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind,
protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind,
protoreflect.EnumKind:
return true
}
return false
}

View File

@ -0,0 +1,422 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"strconv"
"strings"
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"
)
const (
contextPackage = protogen.GoImportPath("context")
grpcPackage = protogen.GoImportPath("google.golang.org/grpc")
codesPackage = protogen.GoImportPath("google.golang.org/grpc/codes")
statusPackage = protogen.GoImportPath("google.golang.org/grpc/status")
)
// generateFileContent generates the gRPC service definitions, excluding the package statement.
func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile) {
if len(file.Services) == 0 {
return
}
g.P("// This is a compile-time assertion to ensure that this generated file")
g.P("// is compatible with the grpc package it is being compiled against.")
g.P("// Requires gRPC-Go v1.32.0 or later.")
g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion7")) // When changing, update version number above.
g.P()
for _, service := range file.Services {
genService(gen, file, g, service)
}
}
func genService(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile, service *protogen.Service) {
clientName := service.GoName + "Client"
g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
g.P("//")
g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.")
// Client interface.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P("//")
g.P(deprecationComment)
}
g.Annotate(clientName, service.Location)
g.P("type ", clientName, " interface {")
for _, method := range service.Methods {
g.Annotate(clientName+"."+method.GoName, method.Location)
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P(method.Comments.Leading,
clientSignature(g, method))
}
g.P("}")
g.P()
// Client structure.
g.P("type ", unexport(clientName), " struct {")
g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
g.P("}")
g.P()
// NewClient factory.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
g.P("return &", unexport(clientName), "{cc}")
g.P("}")
g.P()
var methodIndex, streamIndex int
// Client method implementations.
for _, method := range service.Methods {
if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
// Unary RPC method
genClientMethod(gen, file, g, method, methodIndex)
methodIndex++
} else {
// Streaming RPC method
genClientMethod(gen, file, g, method, streamIndex)
streamIndex++
}
}
mustOrShould := "must"
if !*requireUnimplemented {
mustOrShould = "should"
}
// Server interface.
serverType := service.GoName + "Server"
g.P("// ", serverType, " is the server API for ", service.GoName, " service.")
g.P("// All implementations ", mustOrShould, " embed Unimplemented", serverType)
g.P("// for forward compatibility")
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P("//")
g.P(deprecationComment)
}
g.Annotate(serverType, service.Location)
g.P("type ", serverType, " interface {")
for _, method := range service.Methods {
g.Annotate(serverType+"."+method.GoName, method.Location)
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P(method.Comments.Leading,
serverSignature(g, method))
}
if *requireUnimplemented {
g.P("mustEmbedUnimplemented", serverType, "()")
}
g.P("}")
g.P()
// Server Unimplemented struct for forward compatibility.
g.P("// Unimplemented", serverType, " ", mustOrShould, " be embedded to have forward compatible implementations.")
g.P("type Unimplemented", serverType, " struct {")
g.P("}")
g.P()
for _, method := range service.Methods {
nilArg := ""
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
nilArg = "nil,"
}
g.P("func (Unimplemented", serverType, ") ", serverSignature(g, method), "{")
g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`)
g.P("}")
}
if *requireUnimplemented {
g.P("func (Unimplemented", serverType, ") mustEmbedUnimplemented", serverType, "() {}")
}
g.P()
// Unsafe Server interface to opt-out of forward compatibility.
g.P("// Unsafe", serverType, " may be embedded to opt out of forward compatibility for this service.")
g.P("// Use of this interface is not recommended, as added methods to ", serverType, " will")
g.P("// result in compilation errors.")
g.P("type Unsafe", serverType, " interface {")
g.P("mustEmbedUnimplemented", serverType, "()")
g.P("}")
// Server registration.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P(deprecationComment)
}
serviceDescVar := service.GoName + "_ServiceDesc"
g.P("func Register", service.GoName, "Server(s ", grpcPackage.Ident("ServiceRegistrar"), ", srv ", serverType, ") {")
g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
g.P("}")
g.P()
// Server handler implementations.
var handlerNames []string
for _, method := range service.Methods {
hname := genServerMethod(gen, file, g, method)
handlerNames = append(handlerNames, hname)
}
// Service descriptor.
g.P("// ", serviceDescVar, " is the ", grpcPackage.Ident("ServiceDesc"), " for ", service.GoName, " service.")
g.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",")
g.P("// and not to be introspected or modified (even as a copy)")
g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
g.P("HandlerType: (*", serverType, ")(nil),")
g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
for i, method := range service.Methods {
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
g.P("},")
}
g.P("},")
g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
for i, method := range service.Methods {
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
if method.Desc.IsStreamingServer() {
g.P("ServerStreams: true,")
}
if method.Desc.IsStreamingClient() {
g.P("ClientStreams: true,")
}
g.P("},")
}
g.P("},")
g.P("Metadata: \"", file.Desc.Path(), "\",")
g.P("}")
g.P()
}
func clientSignature(g *generator.GeneratedFile, method *protogen.Method) string {
s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
if !method.Desc.IsStreamingClient() {
s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent)
}
s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") ("
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
s += "*" + g.QualifiedGoIdent(method.Output.GoIdent)
} else {
s += method.Parent.GoName + "_" + method.GoName + "Client"
}
s += ", error)"
return s
}
func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile, method *protogen.Method, index int) {
service := method.Parent
sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{")
if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
// g.P("out := new(", method.Output.GoIdent, ")")
g.Alloc("out", method.Output, true)
g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`)
g.P("if err != nil { return nil, err }")
g.P("return out, nil")
g.P("}")
g.P()
return
}
streamType := unexport(service.GoName) + method.GoName + "Client"
serviceDescVar := service.GoName + "_ServiceDesc"
g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`)
g.P("if err != nil { return nil, err }")
g.P("x := &", streamType, "{stream}")
if !method.Desc.IsStreamingClient() {
g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
}
g.P("return x, nil")
g.P("}")
g.P()
genSend := method.Desc.IsStreamingClient()
genRecv := method.Desc.IsStreamingServer()
genCloseAndRecv := !method.Desc.IsStreamingServer()
// Stream auxiliary types and methods.
g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
if genSend {
g.P("Send(*", method.Input.GoIdent, ") error")
}
if genRecv {
g.P("Recv() (*", method.Output.GoIdent, ", error)")
}
if genCloseAndRecv {
g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
}
g.P(grpcPackage.Ident("ClientStream"))
g.P("}")
g.P()
g.P("type ", streamType, " struct {")
g.P(grpcPackage.Ident("ClientStream"))
g.P("}")
g.P()
if genSend {
g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
g.P("return x.ClientStream.SendMsg(m)")
g.P("}")
g.P()
}
if genRecv {
g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
// g.P("m := new(", method.Output.GoIdent, ")")
g.Alloc("m", method.Output, true)
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
g.P("}")
g.P()
}
if genCloseAndRecv {
g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
// g.P("m := new(", method.Output.GoIdent, ")")
g.Alloc("m", method.Output, true)
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
g.P("}")
g.P()
}
}
func serverSignature(g *generator.GeneratedFile, method *protogen.Method) string {
var reqArgs []string
ret := "error"
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context")))
ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)"
}
if !method.Desc.IsStreamingClient() {
reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent))
}
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
}
return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
}
func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile, method *protogen.Method) string {
service := method.Parent
hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
// g.P("in := new(", method.Input.GoIdent, ")")
g.Alloc("in", method.Input, true)
g.P("if err := dec(in); err != nil { return nil, err }")
g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
g.P("Server: srv,")
g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())), ",")
g.P("}")
g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))")
g.P("}")
g.P("return interceptor(ctx, in, info, handler)")
g.P("}")
g.P()
return hname
}
streamType := unexport(service.GoName) + method.GoName + "Server"
g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
if !method.Desc.IsStreamingClient() {
// g.P("m := new(", method.Input.GoIdent, ")")
g.Alloc("m", method.Input, true)
g.P("if err := stream.RecvMsg(m); err != nil { return err }")
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
} else {
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
}
g.P("}")
g.P()
genSend := method.Desc.IsStreamingServer()
genSendAndClose := !method.Desc.IsStreamingServer()
genRecv := method.Desc.IsStreamingClient()
// Stream auxiliary types and methods.
g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
if genSend {
g.P("Send(*", method.Output.GoIdent, ") error")
}
if genSendAndClose {
g.P("SendAndClose(*", method.Output.GoIdent, ") error")
}
if genRecv {
g.P("Recv() (*", method.Input.GoIdent, ", error)")
}
g.P(grpcPackage.Ident("ServerStream"))
g.P("}")
g.P()
g.P("type ", streamType, " struct {")
g.P(grpcPackage.Ident("ServerStream"))
g.P("}")
g.P()
if genSend {
g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
g.P("return x.ServerStream.SendMsg(m)")
g.P("}")
g.P()
}
if genSendAndClose {
g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
g.P("return x.ServerStream.SendMsg(m)")
g.P("}")
g.P()
}
if genRecv {
g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
// g.P("m := new(", method.Input.GoIdent, ")")
g.Alloc("m", method.Input, true)
g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
g.P("}")
g.P()
}
return hname
}
const deprecationComment = "// Deprecated: Do not use."
func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] }

View File

@ -0,0 +1,34 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package grpc
import (
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
)
const version = "1.1.0-vtproto"
var requireUnimplementedAlways = true
var requireUnimplemented = &requireUnimplementedAlways
func init() {
generator.RegisterFeature("grpc", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &grpc{gen}
})
}
type grpc struct {
*generator.GeneratedFile
}
func (g *grpc) GenerateFile(file *protogen.File) bool {
if len(file.Services) == 0 {
return false
}
generateFileContent(nil, file, g.GeneratedFile)
return true
}

View File

@ -0,0 +1,748 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package marshal
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
)
func init() {
generator.RegisterFeature("marshal", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &marshal{GeneratedFile: gen, Stable: false, strict: false}
})
generator.RegisterFeature("marshal_strict", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &marshal{GeneratedFile: gen, Stable: false, strict: true}
})
}
type counter int
func (cnt *counter) Next() string {
*cnt++
return cnt.Current()
}
func (cnt *counter) Current() string {
return strconv.Itoa(int(*cnt))
}
type marshal struct {
*generator.GeneratedFile
Stable, once, strict bool
}
var _ generator.FeatureGenerator = (*marshal)(nil)
func (p *marshal) GenerateFile(file *protogen.File) bool {
for _, message := range file.Messages {
p.message(message)
}
return p.once
}
func (p *marshal) encodeFixed64(varName ...string) {
p.P(`i -= 8`)
p.P(p.Ident("encoding/binary", "LittleEndian"), `.PutUint64(dAtA[i:], uint64(`, strings.Join(varName, ""), `))`)
}
func (p *marshal) encodeFixed32(varName ...string) {
p.P(`i -= 4`)
p.P(p.Ident("encoding/binary", "LittleEndian"), `.PutUint32(dAtA[i:], uint32(`, strings.Join(varName, ""), `))`)
}
func (p *marshal) encodeVarint(varName ...string) {
p.P(`i = `, p.Helper("EncodeVarint"), `(dAtA, i, uint64(`, strings.Join(varName, ""), `))`)
}
func (p *marshal) encodeKey(fieldNumber protoreflect.FieldNumber, wireType protowire.Type) {
x := uint32(fieldNumber)<<3 | uint32(wireType)
i := 0
keybuf := make([]byte, 0)
for i = 0; x > 127; i++ {
keybuf = append(keybuf, 0x80|uint8(x&0x7F))
x >>= 7
}
keybuf = append(keybuf, uint8(x))
for i = len(keybuf) - 1; i >= 0; i-- {
p.P(`i--`)
p.P(`dAtA[i] = `, fmt.Sprintf("%#v", keybuf[i]))
}
}
func (p *marshal) mapField(kvField *protogen.Field, varName string) {
switch kvField.Desc.Kind() {
case protoreflect.DoubleKind:
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(`, varName, `))`)
case protoreflect.FloatKind:
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(`, varName, `))`)
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.EnumKind:
p.encodeVarint(varName)
case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
p.encodeFixed64(varName)
case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
p.encodeFixed32(varName)
case protoreflect.BoolKind:
p.P(`i--`)
p.P(`if `, varName, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
case protoreflect.StringKind, protoreflect.BytesKind:
p.P(`i -= len(`, varName, `)`)
p.P(`copy(dAtA[i:], `, varName, `)`)
p.encodeVarint(`len(`, varName, `)`)
case protoreflect.Sint32Kind:
p.encodeVarint(`(uint32(`, varName, `) << 1) ^ uint32((`, varName, ` >> 31))`)
case protoreflect.Sint64Kind:
p.encodeVarint(`(uint64(`, varName, `) << 1) ^ uint64((`, varName, ` >> 63))`)
case protoreflect.MessageKind:
p.marshalBackward(varName, true, kvField.Message)
}
}
func (p *marshal) field(oneof bool, numGen *counter, field *protogen.Field) {
fieldname := field.GoName
nullable := field.Message != nil || (!oneof && field.Desc.HasPresence())
repeated := field.Desc.Cardinality() == protoreflect.Repeated
if repeated {
p.P(`if len(m.`, fieldname, `) > 0 {`)
} else if nullable {
if field.Desc.Cardinality() == protoreflect.Required {
p.P(`if m.`, fieldname, ` == nil {`)
p.P(`return 0, `, p.Ident("fmt", "Errorf"), `("proto: required field `, field.Desc.Name(), ` not set")`)
p.P(`} else {`)
} else {
p.P(`if m.`, fieldname, ` != nil {`)
}
}
packed := field.Desc.IsPacked()
wireType := generator.ProtoWireType(field.Desc.Kind())
fieldNumber := field.Desc.Number()
if packed {
wireType = protowire.BytesType
}
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float64bits"), `(float64(`, val, `))`)
p.encodeFixed64("f", numGen.Current())
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 8`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float64bits"), `(float64(`, val, `))`)
p.encodeFixed64("f", numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(*m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(m.`, fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.FloatKind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float32bits"), `(float32(`, val, `))`)
p.encodeFixed32("f" + numGen.Current())
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 4`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float32bits"), `(float32(`, val, `))`)
p.encodeFixed32("f" + numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(*m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.EnumKind:
if packed {
jvar := "j" + numGen.Next()
total := "pksize" + numGen.Next()
p.P(`var `, total, ` int`)
p.P(`for _, num := range m.`, fieldname, ` {`)
p.P(total, ` += `, p.Helper("SizeOfVarint"), `(uint64(num))`)
p.P(`}`)
p.P(`i -= `, total)
p.P(jvar, `:= i`)
switch field.Desc.Kind() {
case protoreflect.Int64Kind, protoreflect.Int32Kind, protoreflect.EnumKind:
p.P(`for _, num1 := range m.`, fieldname, ` {`)
p.P(`num := uint64(num1)`)
default:
p.P(`for _, num := range m.`, fieldname, ` {`)
}
p.P(`for num >= 1<<7 {`)
p.P(`dAtA[`, jvar, `] = uint8(uint64(num)&0x7f|0x80)`)
p.P(`num >>= 7`)
p.P(jvar, `++`)
p.P(`}`)
p.P(`dAtA[`, jvar, `] = uint8(num)`)
p.P(jvar, `++`)
p.P(`}`)
p.encodeVarint(total)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.encodeVarint(val)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeVarint(`*m.`, fieldname)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeVarint(`m.`, fieldname)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeVarint(`m.`, fieldname)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed64(val)
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 8`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed64(val)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed64("*m.", fieldname)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed64("m.", fieldname)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed64("m.", fieldname)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed32(val)
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 4`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed32(val)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed32("*m." + fieldname)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed32("m." + fieldname)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed32("m." + fieldname)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.BoolKind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i--`)
p.P(`if `, val, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i--`)
p.P(`if `, val, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.P(`i--`)
p.P(`if *m.`, fieldname, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` {`)
p.P(`i--`)
p.P(`if m.`, fieldname, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.P(`i--`)
p.P(`if m.`, fieldname, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.StringKind:
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i -= len(`, val, `)`)
p.P(`copy(dAtA[i:], `, val, `)`)
p.encodeVarint(`len(`, val, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.P(`i -= len(*m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], *m.`, fieldname, `)`)
p.encodeVarint(`len(*m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if len(m.`, fieldname, `) > 0 {`)
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.GroupKind:
p.encodeKey(fieldNumber, protowire.EndGroupType)
p.marshalBackward(`m.`+fieldname, false, field.Message)
p.encodeKey(fieldNumber, protowire.StartGroupType)
case protoreflect.MessageKind:
if field.Desc.IsMap() {
goTypK, _ := p.FieldGoType(field.Message.Fields[0])
keyKind := field.Message.Fields[0].Desc.Kind()
valKind := field.Message.Fields[1].Desc.Kind()
var val string
if p.Stable && keyKind != protoreflect.BoolKind {
keysName := `keysFor` + fieldname
p.P(keysName, ` := make([]`, goTypK, `, 0, len(m.`, fieldname, `))`)
p.P(`for k := range m.`, fieldname, ` {`)
p.P(keysName, ` = append(`, keysName, `, `, goTypK, `(k))`)
p.P(`}`)
p.P(p.Ident("sort", "Slice"), `(`, keysName, `, func(i, j int) bool {`)
p.P(`return `, keysName, `[i] < `, keysName, `[j]`)
p.P(`})`)
val = p.reverseListRange(keysName)
} else {
p.P(`for k := range m.`, fieldname, ` {`)
val = "k"
}
if p.Stable {
p.P(`v := m.`, fieldname, `[`, goTypK, `(`, val, `)]`)
} else {
p.P(`v := m.`, fieldname, `[`, val, `]`)
}
p.P(`baseI := i`)
accessor := `v`
p.mapField(field.Message.Fields[1], accessor)
p.encodeKey(2, generator.ProtoWireType(valKind))
p.mapField(field.Message.Fields[0], val)
p.encodeKey(1, generator.ProtoWireType(keyKind))
p.encodeVarint(`baseI - i`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.marshalBackward(val, true, field.Message)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.marshalBackward(`m.`+fieldname, true, field.Message)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.BytesKind:
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i -= len(`, val, `)`)
p.P(`copy(dAtA[i:], `, val, `)`)
p.encodeVarint(`len(`, val, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if !oneof && !field.Desc.HasPresence() {
p.P(`if len(m.`, fieldname, `) > 0 {`)
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Sint32Kind:
if packed {
jvar := "j" + numGen.Next()
total := "pksize" + numGen.Next()
p.P(`var `, total, ` int`)
p.P(`for _, num := range m.`, fieldname, ` {`)
p.P(total, ` += `, p.Helper("SizeOfZigzag"), `(uint64(num))`)
p.P(`}`)
p.P(`i -= `, total)
p.P(jvar, `:= i`)
p.P(`for _, num := range m.`, fieldname, ` {`)
xvar := "x" + numGen.Next()
p.P(xvar, ` := (uint32(num) << 1) ^ uint32((num >> 31))`)
p.P(`for `, xvar, ` >= 1<<7 {`)
p.P(`dAtA[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`)
p.P(jvar, `++`)
p.P(xvar, ` >>= 7`)
p.P(`}`)
p.P(`dAtA[`, jvar, `] = uint8(`, xvar, `)`)
p.P(jvar, `++`)
p.P(`}`)
p.encodeVarint(total)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`x`, numGen.Next(), ` := (uint32(`, val, `) << 1) ^ uint32((`, val, ` >> 31))`)
p.encodeVarint(`x`, numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeVarint(`(uint32(*m.`, fieldname, `) << 1) ^ uint32((*m.`, fieldname, ` >> 31))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Sint64Kind:
if packed {
jvar := "j" + numGen.Next()
total := "pksize" + numGen.Next()
p.P(`var `, total, ` int`)
p.P(`for _, num := range m.`, fieldname, ` {`)
p.P(total, ` += `, p.Helper("SizeOfZigzag"), `(uint64(num))`)
p.P(`}`)
p.P(`i -= `, total)
p.P(jvar, `:= i`)
p.P(`for _, num := range m.`, fieldname, ` {`)
xvar := "x" + numGen.Next()
p.P(xvar, ` := (uint64(num) << 1) ^ uint64((num >> 63))`)
p.P(`for `, xvar, ` >= 1<<7 {`)
p.P(`dAtA[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`)
p.P(jvar, `++`)
p.P(xvar, ` >>= 7`)
p.P(`}`)
p.P(`dAtA[`, jvar, `] = uint8(`, xvar, `)`)
p.P(jvar, `++`)
p.P(`}`)
p.encodeVarint(total)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`x`, numGen.Next(), ` := (uint64(`, val, `) << 1) ^ uint64((`, val, ` >> 63))`)
p.encodeVarint("x" + numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeVarint(`(uint64(*m.`, fieldname, `) << 1) ^ uint64((*m.`, fieldname, ` >> 63))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`)
p.encodeKey(fieldNumber, wireType)
}
default:
panic("not implemented")
}
// Empty protobufs should emit a message or compatibility with Golang protobuf;
// See https://github.com/planetscale/vtprotobuf/issues/61
if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() {
p.P("} else {")
p.P("i = protohelpers.EncodeVarint(dAtA, i, 0)")
p.encodeKey(fieldNumber, wireType)
p.P("}")
} else if repeated || nullable {
p.P(`}`)
}
}
func (p *marshal) methodMarshalToSizedBuffer() string {
switch {
case p.strict:
return "MarshalToSizedBufferVTStrict"
default:
return "MarshalToSizedBufferVT"
}
}
func (p *marshal) methodMarshalTo() string {
switch {
case p.strict:
return "MarshalToVTStrict"
default:
return "MarshalToVT"
}
}
func (p *marshal) methodMarshal() string {
switch {
case p.strict:
return "MarshalVTStrict"
default:
return "MarshalVT"
}
}
func (p *marshal) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
var numGen counter
ccTypeName := message.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshal(), `() (dAtA []byte, err error) {`)
p.P(`if m == nil {`)
p.P(`return nil, nil`)
p.P(`}`)
p.P(`size := m.SizeVT()`)
p.P(`dAtA = make([]byte, size)`)
p.P(`n, err := m.`, p.methodMarshalToSizedBuffer(), `(dAtA[:size])`)
p.P(`if err != nil {`)
p.P(`return nil, err`)
p.P(`}`)
p.P(`return dAtA[:n], nil`)
p.P(`}`)
p.P(``)
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalTo(), `(dAtA []byte) (int, error) {`)
p.P(`size := m.SizeVT()`)
p.P(`return m.`, p.methodMarshalToSizedBuffer(), `(dAtA[:size])`)
p.P(`}`)
p.P(``)
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalToSizedBuffer(), `(dAtA []byte) (int, error) {`)
p.P(`if m == nil {`)
p.P(`return 0, nil`)
p.P(`}`)
p.P(`i := len(dAtA)`)
p.P(`_ = i`)
p.P(`var l int`)
p.P(`_ = l`)
if !p.Wrapper() {
p.P(`if m.unknownFields != nil {`)
p.P(`i -= len(m.unknownFields)`)
p.P(`copy(dAtA[i:], m.unknownFields)`)
p.P(`}`)
}
sort.Slice(message.Fields, func(i, j int) bool {
return message.Fields[i].Desc.Number() < message.Fields[j].Desc.Number()
})
marshalForwardOneOf := func(varname ...any) {
l := []any{`size, err := `}
l = append(l, varname...)
l = append(l, `.`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.P(l...)
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.P(`i -= size`)
}
if p.strict {
for i := len(message.Fields) - 1; i >= 0; i-- {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(false, &numGen, field)
} else {
if p.IsWellKnownType(message) {
p.P(`if m, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent, `); ok {`)
p.P(`msg := ((*`, p.WellKnownFieldMap(field), `)(m))`)
} else {
p.P(`if msg, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent.GoName, `); ok {`)
}
marshalForwardOneOf("msg")
p.P(`}`)
}
}
} else {
// To match the wire format of proto.Marshal, oneofs have to be marshaled
// before fields. See https://github.com/planetscale/vtprotobuf/pull/22
oneofs := make(map[string]struct{}, len(message.Fields))
for i := len(message.Fields) - 1; i >= 0; i-- {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if oneof {
fieldname := field.Oneof.GoName
if _, ok := oneofs[fieldname]; ok {
continue
}
oneofs[fieldname] = struct{}{}
if p.IsWellKnownType(message) {
p.P(`switch c := m.`, fieldname, `.(type) {`)
for _, f := range field.Oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
marshalForwardOneOf(`(*`, p.WellKnownFieldMap(f), `)(c)`)
}
p.P(`}`)
} else {
p.P(`if vtmsg, ok := m.`, fieldname, `.(interface{`)
p.P(p.methodMarshalToSizedBuffer(), ` ([]byte) (int, error)`)
p.P(`}); ok {`)
marshalForwardOneOf("vtmsg")
p.P(`}`)
}
}
}
for i := len(message.Fields) - 1; i >= 0; i-- {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(false, &numGen, field)
}
}
}
p.P(`return len(dAtA) - i, nil`)
p.P(`}`)
p.P()
// Generate MarshalToVT methods for oneof fields
for _, field := range message.Fields {
if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() {
continue
}
ccTypeName := field.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalTo(), `(dAtA []byte) (int, error) {`)
p.P(`size := m.SizeVT()`)
p.P(`return m.`, p.methodMarshalToSizedBuffer(), `(dAtA[:size])`)
p.P(`}`)
p.P(``)
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalToSizedBuffer(), `(dAtA []byte) (int, error) {`)
p.P(`i := len(dAtA)`)
p.field(true, &numGen, field)
p.P(`return len(dAtA) - i, nil`)
p.P(`}`)
}
}
func (p *marshal) reverseListRange(expression ...string) string {
exp := strings.Join(expression, "")
p.P(`for iNdEx := len(`, exp, `) - 1; iNdEx >= 0; iNdEx-- {`)
return exp + `[iNdEx]`
}
func (p *marshal) marshalBackwardSize(varInt bool) {
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.P(`i -= size`)
if varInt {
p.encodeVarint(`size`)
}
}
func (p *marshal) marshalBackward(varName string, varInt bool, message *protogen.Message) {
switch {
case p.IsWellKnownType(message):
p.P(`size, err := (*`, p.WellKnownTypeMap(message), `)(`, varName, `).`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.marshalBackwardSize(varInt)
case p.IsLocalMessage(message):
p.P(`size, err := `, varName, `.`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.marshalBackwardSize(varInt)
default:
p.P(`if vtmsg, ok := interface{}(`, varName, `).(interface{`)
p.P(p.methodMarshalToSizedBuffer(), `([]byte) (int, error)`)
p.P(`}); ok{`)
p.P(`size, err := vtmsg.`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.marshalBackwardSize(varInt)
p.P(`} else {`)
p.P(`encoded, err := `, p.Ident(generator.ProtoPkg, "Marshal"), `(`, varName, `)`)
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.P(`i -= len(encoded)`)
p.P(`copy(dAtA[i:], encoded)`)
if varInt {
p.encodeVarint(`len(encoded)`)
}
p.P(`}`)
}
}

View File

@ -0,0 +1,105 @@
package pool
import (
"fmt"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
func init() {
generator.RegisterFeature("pool", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &pool{GeneratedFile: gen}
})
}
type pool struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*pool)(nil)
func (p *pool) GenerateFile(file *protogen.File) bool {
for _, message := range file.Messages {
p.message(message)
}
return p.once
}
func (p *pool) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(nested)
}
if message.Desc.IsMapEntry() || !p.ShouldPool(message) {
return
}
p.once = true
ccTypeName := message.GoIdent
p.P(`var vtprotoPool_`, ccTypeName, ` = `, p.Ident("sync", "Pool"), `{`)
p.P(`New: func() interface{} {`)
p.P(`return &`, ccTypeName, `{}`)
p.P(`},`)
p.P(`}`)
p.P(`func (m *`, ccTypeName, `) ResetVT() {`)
p.P(`if m != nil {`)
var saved []*protogen.Field
for _, field := range message.Fields {
fieldName := field.GoName
if field.Desc.IsList() {
switch field.Desc.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
p.P(`for _, mm := range m.`, fieldName, `{`)
if p.ShouldPool(field.Message) {
p.P(`mm.ResetVT()`)
} else {
p.P(`mm.Reset()`)
}
p.P(`}`)
}
p.P(fmt.Sprintf("f%d", len(saved)), ` := m.`, fieldName, `[:0]`)
saved = append(saved, field)
} else if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
if p.ShouldPool(field.Message) {
p.P(`if oneof, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent, `); ok {`)
p.P(`oneof.`, fieldName, `.ReturnToVTPool()`)
p.P(`}`)
}
} else {
switch field.Desc.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
if !field.Desc.IsMap() && p.ShouldPool(field.Message) {
p.P(`m.`, fieldName, `.ReturnToVTPool()`)
}
case protoreflect.BytesKind:
p.P(fmt.Sprintf("f%d", len(saved)), ` := m.`, fieldName, `[:0]`)
saved = append(saved, field)
}
}
}
p.P(`m.Reset()`)
for i, field := range saved {
p.P(`m.`, field.GoName, ` = `, fmt.Sprintf("f%d", i))
}
p.P(`}`)
p.P(`}`)
p.P(`func (m *`, ccTypeName, `) ReturnToVTPool() {`)
p.P(`if m != nil {`)
p.P(`m.ResetVT()`)
p.P(`vtprotoPool_`, ccTypeName, `.Put(m)`)
p.P(`}`)
p.P(`}`)
p.P(`func `, ccTypeName, `FromVTPool() *`, ccTypeName, `{`)
p.P(`return vtprotoPool_`, ccTypeName, `.Get().(*`, ccTypeName, `)`)
p.P(`}`)
}

View File

@ -0,0 +1,348 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package size
import (
"strconv"
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
)
func init() {
generator.RegisterFeature("size", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &size{GeneratedFile: gen}
})
}
type size struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*size)(nil)
func (p *size) Name() string {
return "size"
}
func (p *size) GenerateFile(file *protogen.File) bool {
for _, message := range file.Messages {
p.message(message)
}
return p.once
}
func (p *size) messageSize(varName, sizeName string, message *protogen.Message) {
switch {
case p.IsWellKnownType(message):
p.P(`l = (*`, p.WellKnownTypeMap(message), `)(`, varName, `).`, sizeName, `()`)
case p.IsLocalMessage(message):
p.P(`l = `, varName, `.`, sizeName, `()`)
default:
p.P(`if size, ok := interface{}(`, varName, `).(interface{`)
p.P(sizeName, `() int`)
p.P(`}); ok{`)
p.P(`l = size.`, sizeName, `()`)
p.P(`} else {`)
p.P(`l = `, p.Ident(generator.ProtoPkg, "Size"), `(`, varName, `)`)
p.P(`}`)
}
}
func (p *size) field(oneof bool, field *protogen.Field, sizeName string) {
fieldname := field.GoName
nullable := field.Message != nil || (!oneof && field.Desc.HasPresence())
repeated := field.Desc.Cardinality() == protoreflect.Repeated
if repeated {
p.P(`if len(m.`, fieldname, `) > 0 {`)
} else if nullable {
p.P(`if m.`, fieldname, ` != nil {`)
}
packed := field.Desc.IsPacked()
wireType := generator.ProtoWireType(field.Desc.Kind())
fieldNumber := field.Desc.Number()
if packed {
wireType = protowire.BytesType
}
key := generator.KeySize(fieldNumber, wireType)
switch field.Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
if packed {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(len(m.`, fieldname, `)*8))`, `+len(m.`, fieldname, `)*8`)
} else if repeated {
p.P(`n+=`, strconv.Itoa(key+8), `*len(m.`, fieldname, `)`)
} else if !oneof && !nullable {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key+8))
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key+8))
}
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
if packed {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(len(m.`, fieldname, `)*4))`, `+len(m.`, fieldname, `)*4`)
} else if repeated {
p.P(`n+=`, strconv.Itoa(key+4), `*len(m.`, fieldname, `)`)
} else if !oneof && !nullable {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key+4))
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key+4))
}
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
if packed {
p.P(`l = 0`)
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`l+=`, p.Helper("SizeOfVarint"), `(uint64(e))`)
p.P(`}`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(l))+l`)
} else if repeated {
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(e))`)
p.P(`}`)
} else if nullable {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(*m.`, fieldname, `))`)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(m.`, fieldname, `))`)
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(m.`, fieldname, `))`)
}
case protoreflect.BoolKind:
if packed {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(len(m.`, fieldname, `)))`, `+len(m.`, fieldname, `)*1`)
} else if repeated {
p.P(`n+=`, strconv.Itoa(key+1), `*len(m.`, fieldname, `)`)
} else if !oneof && !nullable {
p.P(`if m.`, fieldname, ` {`)
p.P(`n+=`, strconv.Itoa(key+1))
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key+1))
}
case protoreflect.StringKind:
if repeated {
p.P(`for _, s := range m.`, fieldname, ` { `)
p.P(`l = len(s)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else if nullable {
p.P(`l=len(*m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
} else if !oneof {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`if l > 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
}
case protoreflect.GroupKind:
p.messageSize("m."+fieldname, sizeName, field.Message)
p.P(`n+=l+`, strconv.Itoa(2*key))
case protoreflect.MessageKind:
if field.Desc.IsMap() {
fieldKeySize := generator.KeySize(field.Desc.Number(), generator.ProtoWireType(field.Desc.Kind()))
keyKeySize := generator.KeySize(1, generator.ProtoWireType(field.Message.Fields[0].Desc.Kind()))
valueKeySize := generator.KeySize(2, generator.ProtoWireType(field.Message.Fields[1].Desc.Kind()))
p.P(`for k, v := range m.`, fieldname, ` { `)
p.P(`_ = k`)
p.P(`_ = v`)
sum := []interface{}{strconv.Itoa(keyKeySize)}
switch field.Message.Fields[0].Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
sum = append(sum, `8`)
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
sum = append(sum, `4`)
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
sum = append(sum, p.Helper("SizeOfVarint"), `(uint64(k))`)
case protoreflect.BoolKind:
sum = append(sum, `1`)
case protoreflect.StringKind, protoreflect.BytesKind:
sum = append(sum, `len(k)`, p.Helper("SizeOfVarint"), `(uint64(len(k)))`)
case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
sum = append(sum, p.Helper("SizeOfZigzag"), `(uint64(k))`)
}
switch field.Message.Fields[1].Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, strconv.Itoa(8))
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, strconv.Itoa(4))
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, p.Helper("SizeOfVarint"), `(uint64(v))`)
case protoreflect.BoolKind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, `1`)
case protoreflect.StringKind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, `len(v)`, p.Helper("SizeOfVarint"), `(uint64(len(v)))`)
case protoreflect.BytesKind:
p.P(`l = `, strconv.Itoa(valueKeySize), ` + len(v)+`, p.Helper("SizeOfVarint"), `(uint64(len(v)))`)
sum = append(sum, `l`)
case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, p.Helper("SizeOfZigzag"), `(uint64(v))`)
case protoreflect.MessageKind:
p.P(`l = 0`)
p.P(`if v != nil {`)
p.messageSize("v", sizeName, field.Message.Fields[1].Message)
p.P(`}`)
p.P(`l += `, strconv.Itoa(valueKeySize), ` + `, p.Helper("SizeOfVarint"), `(uint64(l))`)
sum = append(sum, `l`)
}
mapEntrySize := []interface{}{"mapEntrySize := "}
for i, elt := range sum {
mapEntrySize = append(mapEntrySize, elt)
// if elt is not a string, then it is a helper function call
if _, ok := elt.(string); ok && i < len(sum)-1 {
mapEntrySize = append(mapEntrySize, "+")
}
}
p.P(mapEntrySize...)
p.P(`n+=mapEntrySize+`, fieldKeySize, `+`, p.Helper("SizeOfVarint"), `(uint64(mapEntrySize))`)
p.P(`}`)
} else if field.Desc.IsList() {
p.P(`for _, e := range m.`, fieldname, ` { `)
p.messageSize("e", sizeName, field.Message)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else {
p.messageSize("m."+fieldname, sizeName, field.Message)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
}
case protoreflect.BytesKind:
if repeated {
p.P(`for _, b := range m.`, fieldname, ` { `)
p.P(`l = len(b)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else if !oneof && !field.Desc.HasPresence() {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`if l > 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
}
case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
if packed {
p.P(`l = 0`)
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`l+=`, p.Helper("SizeOfZigzag"), `(uint64(e))`)
p.P(`}`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(l))+l`)
} else if repeated {
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(e))`)
p.P(`}`)
} else if nullable {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(*m.`, fieldname, `))`)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(m.`, fieldname, `))`)
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(m.`, fieldname, `))`)
}
default:
panic("not implemented")
}
// Empty protobufs should emit a message or compatibility with Golang protobuf;
// See https://github.com/planetscale/vtprotobuf/issues/61
// Size is always 3 so just hardcode that here
if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() {
p.P("} else { n += 3 }")
} else if repeated || nullable {
p.P(`}`)
}
}
func (p *size) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
sizeName := "SizeVT"
ccTypeName := message.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
p.P(`if m == nil {`)
p.P(`return 0`)
p.P(`}`)
p.P(`var l int`)
p.P(`_ = l`)
oneofs := make(map[string]struct{})
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(false, field, sizeName)
} else {
fieldname := field.Oneof.GoName
if _, ok := oneofs[fieldname]; ok {
continue
}
oneofs[fieldname] = struct{}{}
if p.IsWellKnownType(message) {
p.P(`switch c := m.`, fieldname, `.(type) {`)
for _, f := range field.Oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
p.P(`n += (*`, p.WellKnownFieldMap(f), `)(c).`, sizeName, `()`)
}
p.P(`}`)
} else {
p.P(`if vtmsg, ok := m.`, fieldname, `.(interface{ SizeVT() int }); ok {`)
p.P(`n+=vtmsg.`, sizeName, `()`)
p.P(`}`)
}
}
}
if !p.Wrapper() {
p.P(`n+=len(m.unknownFields)`)
}
p.P(`return n`)
p.P(`}`)
p.P()
for _, field := range message.Fields {
if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() {
continue
}
ccTypeName := field.GoIdent
if p.IsWellKnownType(message) && p.IsLocalMessage(message) {
ccTypeName.GoImportPath = ""
}
p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
p.P(`if m == nil {`)
p.P(`return 0`)
p.P(`}`)
p.P(`var l int`)
p.P(`_ = l`)
p.field(true, field, sizeName)
p.P(`return n`)
p.P(`}`)
}
}

View File

@ -0,0 +1,872 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unmarshal
import (
"fmt"
"strconv"
"strings"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
func init() {
generator.RegisterFeature("unmarshal", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &unmarshal{GeneratedFile: gen}
})
generator.RegisterFeature("unmarshal_unsafe", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &unmarshal{GeneratedFile: gen, unsafe: true}
})
}
type unmarshal struct {
*generator.GeneratedFile
unsafe bool
once bool
}
var _ generator.FeatureGenerator = (*unmarshal)(nil)
func (p *unmarshal) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.message(proto3, message)
}
return p.once
}
func (p *unmarshal) methodUnmarshal() string {
if p.unsafe {
return "UnmarshalVTUnsafe"
}
return "UnmarshalVT"
}
func (p *unmarshal) decodeMessage(varName, buf string, message *protogen.Message) {
switch {
case p.IsWellKnownType(message):
p.P(`if err := (*`, p.WellKnownTypeMap(message), `)(`, varName, `).`, p.methodUnmarshal(), `(`, buf, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
case p.IsLocalMessage(message):
p.P(`if err := `, varName, `.`, p.methodUnmarshal(), `(`, buf, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
default:
p.P(`if unmarshal, ok := interface{}(`, varName, `).(interface{`)
p.P(p.methodUnmarshal(), `([]byte) error`)
p.P(`}); ok{`)
p.P(`if err := unmarshal.`, p.methodUnmarshal(), `(`, buf, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`} else {`)
p.P(`if err := `, p.Ident(generator.ProtoPkg, "Unmarshal"), `(`, buf, `, `, varName, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`}`)
}
}
func (p *unmarshal) decodeVarint(varName string, typName string) {
p.P(`for shift := uint(0); ; shift += 7 {`)
p.P(`if shift >= 64 {`)
p.P(`return `, p.Helper("ErrIntOverflow"))
p.P(`}`)
p.P(`if iNdEx >= l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(`b := dAtA[iNdEx]`)
p.P(`iNdEx++`)
p.P(varName, ` |= `, typName, `(b&0x7F) << shift`)
p.P(`if b < 0x80 {`)
p.P(`break`)
p.P(`}`)
p.P(`}`)
}
func (p *unmarshal) decodeFixed32(varName string, typeName string) {
p.P(`if (iNdEx+4) > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(varName, ` = `, typeName, `(`, p.Ident("encoding/binary", "LittleEndian"), `.Uint32(dAtA[iNdEx:]))`)
p.P(`iNdEx += 4`)
}
func (p *unmarshal) decodeFixed64(varName string, typeName string) {
p.P(`if (iNdEx+8) > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(varName, ` = `, typeName, `(`, p.Ident("encoding/binary", "LittleEndian"), `.Uint64(dAtA[iNdEx:]))`)
p.P(`iNdEx += 8`)
}
func (p *unmarshal) declareMapField(varName string, nullable bool, field *protogen.Field) {
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
p.P(`var `, varName, ` float64`)
case protoreflect.FloatKind:
p.P(`var `, varName, ` float32`)
case protoreflect.Int64Kind:
p.P(`var `, varName, ` int64`)
case protoreflect.Uint64Kind:
p.P(`var `, varName, ` uint64`)
case protoreflect.Int32Kind:
p.P(`var `, varName, ` int32`)
case protoreflect.Fixed64Kind:
p.P(`var `, varName, ` uint64`)
case protoreflect.Fixed32Kind:
p.P(`var `, varName, ` uint32`)
case protoreflect.BoolKind:
p.P(`var `, varName, ` bool`)
case protoreflect.StringKind:
p.P(`var `, varName, ` `, field.GoIdent)
case protoreflect.MessageKind:
msgname := field.GoIdent
if nullable {
p.P(`var `, varName, ` *`, msgname)
} else {
p.P(varName, ` := &`, msgname, `{}`)
}
case protoreflect.BytesKind:
p.P(varName, ` := []byte{}`)
case protoreflect.Uint32Kind:
p.P(`var `, varName, ` uint32`)
case protoreflect.EnumKind:
p.P(`var `, varName, ` `, field.GoIdent)
case protoreflect.Sfixed32Kind:
p.P(`var `, varName, ` int32`)
case protoreflect.Sfixed64Kind:
p.P(`var `, varName, ` int64`)
case protoreflect.Sint32Kind:
p.P(`var `, varName, ` int32`)
case protoreflect.Sint64Kind:
p.P(`var `, varName, ` int64`)
}
}
func (p *unmarshal) mapField(varName string, field *protogen.Field) {
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
p.P(`var `, varName, `temp uint64`)
p.decodeFixed64(varName+"temp", "uint64")
p.P(varName, ` = `, p.Ident("math", "Float64frombits"), `(`, varName, `temp)`)
case protoreflect.FloatKind:
p.P(`var `, varName, `temp uint32`)
p.decodeFixed32(varName+"temp", "uint32")
p.P(varName, ` = `, p.Ident("math", "Float32frombits"), `(`, varName, `temp)`)
case protoreflect.Int64Kind:
p.decodeVarint(varName, "int64")
case protoreflect.Uint64Kind:
p.decodeVarint(varName, "uint64")
case protoreflect.Int32Kind:
p.decodeVarint(varName, "int32")
case protoreflect.Fixed64Kind:
p.decodeFixed64(varName, "uint64")
case protoreflect.Fixed32Kind:
p.decodeFixed32(varName, "uint32")
case protoreflect.BoolKind:
p.P(`var `, varName, `temp int`)
p.decodeVarint(varName+"temp", "int")
p.P(varName, ` = bool(`, varName, `temp != 0)`)
case protoreflect.StringKind:
p.P(`var stringLen`, varName, ` uint64`)
p.decodeVarint("stringLen"+varName, "uint64")
p.P(`intStringLen`, varName, ` := int(stringLen`, varName, `)`)
p.P(`if intStringLen`, varName, ` < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postStringIndex`, varName, ` := iNdEx + intStringLen`, varName)
p.P(`if postStringIndex`, varName, ` < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postStringIndex`, varName, ` > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if p.unsafe {
p.P(`if intStringLen`, varName, ` == 0 {`)
p.P(varName, ` = ""`)
p.P(`} else {`)
p.P(varName, ` = `, p.Ident("unsafe", `String`), `(&dAtA[iNdEx], intStringLen`, varName, `)`)
p.P(`}`)
} else {
p.P(varName, ` = `, "string", `(dAtA[iNdEx:postStringIndex`, varName, `])`)
}
p.P(`iNdEx = postStringIndex`, varName)
case protoreflect.MessageKind:
p.P(`var mapmsglen int`)
p.decodeVarint("mapmsglen", "int")
p.P(`if mapmsglen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postmsgIndex := iNdEx + mapmsglen`)
p.P(`if postmsgIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postmsgIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
buf := `dAtA[iNdEx:postmsgIndex]`
p.P(varName, ` = &`, p.noStarOrSliceType(field), `{}`)
p.decodeMessage(varName, buf, field.Message)
p.P(`iNdEx = postmsgIndex`)
case protoreflect.BytesKind:
p.P(`var mapbyteLen uint64`)
p.decodeVarint("mapbyteLen", "uint64")
p.P(`intMapbyteLen := int(mapbyteLen)`)
p.P(`if intMapbyteLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postbytesIndex := iNdEx + intMapbyteLen`)
p.P(`if postbytesIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postbytesIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if p.unsafe {
p.P(varName, ` = dAtA[iNdEx:postbytesIndex]`)
} else {
p.P(varName, ` = make([]byte, mapbyteLen)`)
p.P(`copy(`, varName, `, dAtA[iNdEx:postbytesIndex])`)
}
p.P(`iNdEx = postbytesIndex`)
case protoreflect.Uint32Kind:
p.decodeVarint(varName, "uint32")
case protoreflect.EnumKind:
goTypV, _ := p.FieldGoType(field)
p.decodeVarint(varName, goTypV)
case protoreflect.Sfixed32Kind:
p.decodeFixed32(varName, "int32")
case protoreflect.Sfixed64Kind:
p.decodeFixed64(varName, "int64")
case protoreflect.Sint32Kind:
p.P(`var `, varName, `temp int32`)
p.decodeVarint(varName+"temp", "int32")
p.P(varName, `temp = int32((uint32(`, varName, `temp) >> 1) ^ uint32(((`, varName, `temp&1)<<31)>>31))`)
p.P(varName, ` = int32(`, varName, `temp)`)
case protoreflect.Sint64Kind:
p.P(`var `, varName, `temp uint64`)
p.decodeVarint(varName+"temp", "uint64")
p.P(varName, `temp = (`, varName, `temp >> 1) ^ uint64((int64(`, varName, `temp&1)<<63)>>63)`)
p.P(varName, ` = int64(`, varName, `temp)`)
}
}
func (p *unmarshal) noStarOrSliceType(field *protogen.Field) string {
typ, _ := p.FieldGoType(field)
if typ[0] == '[' && typ[1] == ']' {
typ = typ[2:]
}
if typ[0] == '*' {
typ = typ[1:]
}
return typ
}
func (p *unmarshal) fieldItem(field *protogen.Field, fieldname string, message *protogen.Message, proto3 bool) {
repeated := field.Desc.Cardinality() == protoreflect.Repeated
typ := p.noStarOrSliceType(field)
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
nullable := field.Oneof != nil && field.Oneof.Desc.IsSynthetic()
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
p.P(`var v uint64`)
p.decodeFixed64("v", "uint64")
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", typ, "(", p.Ident("math", `Float64frombits`), `(v))}`)
} else if repeated {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float64frombits"), `(v))`)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v2)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, "(", p.Ident("math", "Float64frombits"), `(v))`)
} else {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float64frombits"), `(v))`)
p.P(`m.`, fieldname, ` = &v2`)
}
case protoreflect.FloatKind:
p.P(`var v uint32`)
p.decodeFixed32("v", "uint32")
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", typ, "(", p.Ident("math", "Float32frombits"), `(v))}`)
} else if repeated {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float32frombits"), `(v))`)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v2)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, "(", p.Ident("math", "Float32frombits"), `(v))`)
} else {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float32frombits"), `(v))`)
p.P(`m.`, fieldname, ` = &v2`)
}
case protoreflect.Int64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Uint64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Int32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Fixed64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed64("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Fixed32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed32("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.BoolKind:
p.P(`var v int`)
p.decodeVarint("v", "int")
if oneof {
p.P(`b := `, typ, `(v != 0)`)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: b}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(v != 0))`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, `(v != 0)`)
} else {
p.P(`b := `, typ, `(v != 0)`)
p.P(`m.`, fieldname, ` = &b`)
}
case protoreflect.StringKind:
p.P(`var stringLen uint64`)
p.decodeVarint("stringLen", "uint64")
p.P(`intStringLen := int(stringLen)`)
p.P(`if intStringLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + intStringLen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
str := "string(dAtA[iNdEx:postIndex])"
if p.unsafe {
str = "stringValue"
p.P(`var stringValue string`)
p.P(`if intStringLen > 0 {`)
p.P(`stringValue = `, p.Ident("unsafe", `String`), `(&dAtA[iNdEx], intStringLen)`)
p.P(`}`)
}
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", str, `}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, str, `)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, str)
} else {
p.P(`s := `, str)
p.P(`m.`, fieldname, ` = &s`)
}
p.P(`iNdEx = postIndex`)
case protoreflect.GroupKind:
p.P(`groupStart := iNdEx`)
p.P(`for {`)
p.P(`maybeGroupEnd := iNdEx`)
p.P(`var groupFieldWire uint64`)
p.decodeVarint("groupFieldWire", "uint64")
p.P(`groupWireType := int(wire & 0x7)`)
p.P(`if groupWireType == `, strconv.Itoa(int(protowire.EndGroupType)), `{`)
p.decodeMessage("m."+fieldname, "dAtA[groupStart:maybeGroupEnd]", field.Message)
p.P(`break`)
p.P(`}`)
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`if (skippy < 0) || (iNdEx + skippy) < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`iNdEx += skippy`)
p.P(`}`)
case protoreflect.MessageKind:
p.P(`var msglen int`)
p.decodeVarint("msglen", "int")
p.P(`if msglen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + msglen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if oneof {
buf := `dAtA[iNdEx:postIndex]`
msgname := p.noStarOrSliceType(field)
p.P(`if oneof, ok := m.`, fieldname, `.(*`, field.GoIdent, `); ok {`)
p.decodeMessage("oneof."+field.GoName, buf, field.Message)
p.P(`} else {`)
if p.ShouldPool(message) && p.ShouldPool(field.Message) {
p.P(`v := `, msgname, `FromVTPool()`)
} else {
p.P(`v := &`, msgname, `{}`)
}
p.decodeMessage("v", buf, field.Message)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
p.P(`}`)
} else if field.Desc.IsMap() {
goTyp, _ := p.FieldGoType(field)
goTypK, _ := p.FieldGoType(field.Message.Fields[0])
goTypV, _ := p.FieldGoType(field.Message.Fields[1])
p.P(`if m.`, fieldname, ` == nil {`)
p.P(`m.`, fieldname, ` = make(`, goTyp, `)`)
p.P(`}`)
p.P("var mapkey ", goTypK)
p.P("var mapvalue ", goTypV)
p.P(`for iNdEx < postIndex {`)
p.P(`entryPreIndex := iNdEx`)
p.P(`var wire uint64`)
p.decodeVarint("wire", "uint64")
p.P(`fieldNum := int32(wire >> 3)`)
p.P(`if fieldNum == 1 {`)
p.mapField("mapkey", field.Message.Fields[0])
p.P(`} else if fieldNum == 2 {`)
p.mapField("mapvalue", field.Message.Fields[1])
p.P(`} else {`)
p.P(`iNdEx = entryPreIndex`)
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`if (skippy < 0) || (iNdEx + skippy) < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if (iNdEx + skippy) > postIndex {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(`iNdEx += skippy`)
p.P(`}`)
p.P(`}`)
p.P(`m.`, fieldname, `[mapkey] = mapvalue`)
} else if repeated {
if p.ShouldPool(message) {
p.P(`if len(m.`, fieldname, `) == cap(m.`, fieldname, `) {`)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, &`, field.Message.GoIdent, `{})`)
p.P(`} else {`)
p.P(`m.`, fieldname, ` = m.`, fieldname, `[:len(m.`, fieldname, `) + 1]`)
p.P(`if m.`, fieldname, `[len(m.`, fieldname, `) - 1] == nil {`)
p.P(`m.`, fieldname, `[len(m.`, fieldname, `) - 1] = &`, field.Message.GoIdent, `{}`)
p.P(`}`)
p.P(`}`)
} else {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, &`, field.Message.GoIdent, `{})`)
}
varname := fmt.Sprintf("m.%s[len(m.%s) - 1]", fieldname, fieldname)
buf := `dAtA[iNdEx:postIndex]`
p.decodeMessage(varname, buf, field.Message)
} else {
p.P(`if m.`, fieldname, ` == nil {`)
if p.ShouldPool(message) && p.ShouldPool(field.Message) {
p.P(`m.`, fieldname, ` = `, field.Message.GoIdent, `FromVTPool()`)
} else {
p.P(`m.`, fieldname, ` = &`, field.Message.GoIdent, `{}`)
}
p.P(`}`)
p.decodeMessage("m."+fieldname, "dAtA[iNdEx:postIndex]", field.Message)
}
p.P(`iNdEx = postIndex`)
case protoreflect.BytesKind:
p.P(`var byteLen int`)
p.decodeVarint("byteLen", "int")
p.P(`if byteLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + byteLen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if oneof {
if p.unsafe {
p.P(`v := dAtA[iNdEx:postIndex]`)
} else {
p.P(`v := make([]byte, postIndex-iNdEx)`)
p.P(`copy(v, dAtA[iNdEx:postIndex])`)
}
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
if p.unsafe {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, dAtA[iNdEx:postIndex])`)
} else {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, make([]byte, postIndex-iNdEx))`)
p.P(`copy(m.`, fieldname, `[len(m.`, fieldname, `)-1], dAtA[iNdEx:postIndex])`)
}
} else {
if p.unsafe {
p.P(`m.`, fieldname, ` = dAtA[iNdEx:postIndex]`)
} else {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `[:0] , dAtA[iNdEx:postIndex]...)`)
p.P(`if m.`, fieldname, ` == nil {`)
p.P(`m.`, fieldname, ` = []byte{}`)
p.P(`}`)
}
}
p.P(`iNdEx = postIndex`)
case protoreflect.Uint32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.EnumKind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sfixed32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed32("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sfixed64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed64("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sint32Kind:
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`v = `, typ, `((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31))`)
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = v`)
} else {
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sint64Kind:
p.P(`var v uint64`)
p.decodeVarint("v", "uint64")
p.P(`v = (v >> 1) ^ uint64((int64(v&1)<<63)>>63)`)
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", typ, `(v)}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(v))`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, `(v)`)
} else {
p.P(`v2 := `, typ, `(v)`)
p.P(`m.`, fieldname, ` = &v2`)
}
default:
panic("not implemented")
}
}
func (p *unmarshal) field(proto3, oneof bool, field *protogen.Field, message *protogen.Message, required protoreflect.FieldNumbers) {
fieldname := field.GoName
errFieldname := fieldname
if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
fieldname = field.Oneof.GoName
}
p.P(`case `, strconv.Itoa(int(field.Desc.Number())), `:`)
wireType := generator.ProtoWireType(field.Desc.Kind())
if field.Desc.IsList() && wireType != protowire.BytesType {
p.P(`if wireType == `, strconv.Itoa(int(wireType)), `{`)
p.fieldItem(field, fieldname, message, false)
p.P(`} else if wireType == `, strconv.Itoa(int(protowire.BytesType)), `{`)
p.P(`var packedLen int`)
p.decodeVarint("packedLen", "int")
p.P(`if packedLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + packedLen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", "ErrUnexpectedEOF"))
p.P(`}`)
p.P(`var elementCount int`)
switch field.Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
p.P(`elementCount = packedLen/`, 8)
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
p.P(`elementCount = packedLen/`, 4)
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind, protoreflect.Sint64Kind:
p.P(`var count int`)
p.P(`for _, integer := range dAtA[iNdEx:postIndex] {`)
p.P(`if integer < 128 {`)
p.P(`count++`)
p.P(`}`)
p.P(`}`)
p.P(`elementCount = count`)
case protoreflect.BoolKind:
p.P(`elementCount = packedLen`)
}
if p.ShouldPool(message) {
p.P(`if elementCount != 0 && len(m.`, fieldname, `) == 0 && cap(m.`, fieldname, `) < elementCount {`)
} else {
p.P(`if elementCount != 0 && len(m.`, fieldname, `) == 0 {`)
}
fieldtyp, _ := p.FieldGoType(field)
p.P(`m.`, fieldname, ` = make(`, fieldtyp, `, 0, elementCount)`)
p.P(`}`)
p.P(`for iNdEx < postIndex {`)
p.fieldItem(field, fieldname, message, false)
p.P(`}`)
p.P(`} else {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: wrong wireType = %d for field `, errFieldname, `", wireType)`)
p.P(`}`)
} else {
p.P(`if wireType != `, strconv.Itoa(int(wireType)), `{`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: wrong wireType = %d for field `, errFieldname, `", wireType)`)
p.P(`}`)
p.fieldItem(field, fieldname, message, proto3)
}
if field.Desc.Cardinality() == protoreflect.Required {
var fieldBit int
for fieldBit = 0; fieldBit < required.Len(); fieldBit++ {
if required.Get(fieldBit) == field.Desc.Number() {
break
}
}
if fieldBit == required.Len() {
panic("missing required field")
}
p.P(`hasFields[`, strconv.Itoa(fieldBit/64), `] |= uint64(`, fmt.Sprintf("0x%08x", uint64(1)<<(fieldBit%64)), `)`)
}
}
func (p *unmarshal) message(proto3 bool, message *protogen.Message) {
for _, nested := range message.Messages {
p.message(proto3, nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
ccTypeName := message.GoIdent.GoName
required := message.Desc.RequiredNumbers()
p.P(`func (m *`, ccTypeName, `) `, p.methodUnmarshal(), `(dAtA []byte) error {`)
if required.Len() > 0 {
p.P(`var hasFields [`, strconv.Itoa(1+(required.Len()-1)/64), `]uint64`)
}
p.P(`l := len(dAtA)`)
p.P(`iNdEx := 0`)
p.P(`for iNdEx < l {`)
p.P(`preIndex := iNdEx`)
p.P(`var wire uint64`)
p.decodeVarint("wire", "uint64")
p.P(`fieldNum := int32(wire >> 3)`)
p.P(`wireType := int(wire & 0x7)`)
p.P(`if wireType == `, strconv.Itoa(int(protowire.EndGroupType)), ` {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: `, message.GoIdent.GoName, `: wiretype end group for non-group")`)
p.P(`}`)
p.P(`if fieldNum <= 0 {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: `, message.GoIdent.GoName, `: illegal tag %d (wire type %d)", fieldNum, wire)`)
p.P(`}`)
p.P(`switch fieldNum {`)
for _, field := range message.Fields {
p.field(proto3, false, field, message, required)
}
p.P(`default:`)
p.P(`iNdEx=preIndex`)
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`if (skippy < 0) || (iNdEx + skippy) < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if (iNdEx + skippy) > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if message.Desc.ExtensionRanges().Len() > 0 {
c := []string{}
eranges := message.Desc.ExtensionRanges()
for e := 0; e < eranges.Len(); e++ {
erange := eranges.Get(e)
c = append(c, `((fieldNum >= `+strconv.Itoa(int(erange[0]))+`) && (fieldNum < `+strconv.Itoa(int(erange[1]))+`))`)
}
p.P(`if `, strings.Join(c, "||"), `{`)
p.P(`err = `, p.Ident(generator.ProtoPkg, "UnmarshalOptions"), `{AllowPartial: true}.Unmarshal(dAtA[iNdEx:iNdEx+skippy], m)`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`iNdEx += skippy`)
p.P(`} else {`)
}
if !p.Wrapper() {
p.P(`m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...)`)
}
p.P(`iNdEx += skippy`)
if message.Desc.ExtensionRanges().Len() > 0 {
p.P(`}`)
}
p.P(`}`)
p.P(`}`)
for _, field := range message.Fields {
if field.Desc.Cardinality() != protoreflect.Required {
continue
}
var fieldBit int
for fieldBit = 0; fieldBit < required.Len(); fieldBit++ {
if required.Get(fieldBit) == field.Desc.Number() {
break
}
}
if fieldBit == required.Len() {
panic("missing required field")
}
p.P(`if hasFields[`, strconv.Itoa(int(fieldBit/64)), `] & uint64(`, fmt.Sprintf("0x%08x", uint64(1)<<(fieldBit%64)), `) == 0 {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: required field `, field.Desc.Name(), ` not set")`)
p.P(`}`)
}
p.P()
p.P(`if iNdEx > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(`return nil`)
p.P(`}`)
}