hack: generate vtproto files for buildx

Integrates vtproto into buildx. The generated files dockerfile has been
modified to copy the buildkit equivalent file to ensure files are laid
out in the appropriate way for imports.

An import has also been included to change the grpc codec to the version
in buildkit that supports vtproto. This will allow buildx to utilize the
speed and memory improvements from that.

Also updates the gc control options for prune.

Signed-off-by: Jonathan A. Sternberg <jonathan.sternberg@docker.com>
This commit is contained in:
Jonathan A. Sternberg
2024-10-08 13:35:06 -05:00
parent d353f5f6ba
commit 64c5139ab6
109 changed files with 68070 additions and 2941 deletions

29
vendor/github.com/planetscale/vtprotobuf/LICENSE generated vendored Normal file
View File

@ -0,0 +1,29 @@
Copyright (c) 2021, PlanetScale Inc. All rights reserved.
Copyright (c) 2013, The GoGo Authors. All rights reserved.
Copyright (c) 2018 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,41 @@
package main
import (
"flag"
"strings"
_ "github.com/planetscale/vtprotobuf/features/clone"
_ "github.com/planetscale/vtprotobuf/features/equal"
_ "github.com/planetscale/vtprotobuf/features/grpc"
_ "github.com/planetscale/vtprotobuf/features/marshal"
_ "github.com/planetscale/vtprotobuf/features/pool"
_ "github.com/planetscale/vtprotobuf/features/size"
_ "github.com/planetscale/vtprotobuf/features/unmarshal"
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
)
func main() {
var cfg generator.Config
var features string
var f flag.FlagSet
f.BoolVar(&cfg.AllowEmpty, "allow-empty", false, "allow generation of empty files")
cfg.Poolable = generator.NewObjectSet()
cfg.PoolableExclude = generator.NewObjectSet()
f.Var(&cfg.Poolable, "pool", "use memory pooling for this object")
f.Var(&cfg.PoolableExclude, "pool-exclude", "do not use memory pooling for this object")
f.BoolVar(&cfg.Wrap, "wrap", false, "generate wrapper types")
f.StringVar(&features, "features", "all", "list of features to generate (separated by '+')")
f.StringVar(&cfg.BuildTag, "buildTag", "", "the go:build tag to set on generated files")
protogen.Options{ParamFunc: f.Set}.Run(func(plugin *protogen.Plugin) error {
gen, err := generator.NewGenerator(plugin, strings.Split(features, "+"), &cfg)
if err != nil {
return err
}
gen.Generate()
return nil
})
}

View File

@ -0,0 +1,338 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package clone
import (
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
const (
cloneName = "CloneVT"
cloneMessageName = "CloneMessageVT"
)
var (
protoPkg = protogen.GoImportPath("google.golang.org/protobuf/proto")
)
func init() {
generator.RegisterFeature("clone", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &clone{GeneratedFile: gen}
})
}
type clone struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*clone)(nil)
func (p *clone) Name() string {
return "clone"
}
func (p *clone) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.processMessage(proto3, message)
}
return p.once
}
// cloneOneofField generates the statements for cloning a oneof field
func (p *clone) cloneOneofField(lhsBase, rhsBase string, oneof *protogen.Oneof) {
fieldname := oneof.GoName
ccInterfaceName := "is" + oneof.GoIdent.GoName
lhs := lhsBase + "." + fieldname
rhs := rhsBase + "." + fieldname
p.P(`if `, rhs, ` != nil {`)
if p.IsWellKnownType(oneof.Parent) {
p.P(`switch c := `, rhs, `.(type) {`)
for _, f := range oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
p.P(lhs, `= (*`, f.GoIdent, `)((*`, p.WellKnownFieldMap(f), `)(c).`, cloneName, `())`)
}
p.P(`}`)
} else {
p.P(lhs, ` = `, rhs, `.(interface{ `, cloneName, `() `, ccInterfaceName, ` }).`, cloneName, `()`)
}
p.P(`}`)
}
// cloneFieldSingular generates the code for cloning a singular, non-oneof field.
func (p *clone) cloneFieldSingular(lhs, rhs string, kind protoreflect.Kind, message *protogen.Message) {
switch {
case kind == protoreflect.MessageKind, kind == protoreflect.GroupKind:
switch {
case p.IsWellKnownType(message):
p.P(lhs, ` = (*`, message.GoIdent, `)((*`, p.WellKnownTypeMap(message), `)(`, rhs, `).`, cloneName, `())`)
case p.IsLocalMessage(message):
p.P(lhs, ` = `, rhs, `.`, cloneName, `()`)
default:
// rhs is a concrete type, we need to first convert it to an interface in order to use an interface
// type assertion.
p.P(`if vtpb, ok := interface{}(`, rhs, `).(interface{ `, cloneName, `() *`, message.GoIdent, ` }); ok {`)
p.P(lhs, ` = vtpb.`, cloneName, `()`)
p.P(`} else {`)
p.P(lhs, ` = `, protoPkg.Ident("Clone"), `(`, rhs, `).(*`, message.GoIdent, `)`)
p.P(`}`)
}
case kind == protoreflect.BytesKind:
p.P(`tmpBytes := make([]byte, len(`, rhs, `))`)
p.P(`copy(tmpBytes, `, rhs, `)`)
p.P(lhs, ` = tmpBytes`)
case isScalar(kind):
p.P(lhs, ` = `, rhs)
default:
panic("unexpected")
}
}
// cloneField generates the code for cloning a field in a protobuf.
func (p *clone) cloneField(lhsBase, rhsBase string, allFieldsNullable bool, field *protogen.Field) {
// At this point, if we encounter a non-synthetic oneof, we assume it to be the representative
// field for that oneof.
if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
p.cloneOneofField(lhsBase, rhsBase, field.Oneof)
return
}
if !isReference(allFieldsNullable, field) {
panic("method should not be invoked for non-reference fields")
}
fieldname := field.GoName
lhs := lhsBase + "." + fieldname
rhs := rhsBase + "." + fieldname
// At this point, we are only looking at reference types (pointers, maps, slices, interfaces), which can all
// be nil.
p.P(`if rhs := `, rhs, `; rhs != nil {`)
rhs = "rhs"
fieldKind := field.Desc.Kind()
msg := field.Message // possibly nil
if field.Desc.Cardinality() == protoreflect.Repeated { // maps and slices
goType, _ := p.FieldGoType(field)
p.P(`tmpContainer := make(`, goType, `, len(`, rhs, `))`)
if isScalar(fieldKind) && field.Desc.IsList() {
// Generated code optimization: instead of iterating over all (key/index, value) pairs,
// do a single copy(dst, src) invocation for slices whose elements aren't reference types.
p.P(`copy(tmpContainer, `, rhs, `)`)
} else {
if field.Desc.IsMap() {
// For maps, the type of the value field determines what code is generated for cloning
// an entry.
valueField := field.Message.Fields[1]
fieldKind = valueField.Desc.Kind()
msg = valueField.Message
}
p.P(`for k, v := range `, rhs, ` {`)
p.cloneFieldSingular("tmpContainer[k]", "v", fieldKind, msg)
p.P(`}`)
}
p.P(lhs, ` = tmpContainer`)
} else if isScalar(fieldKind) {
p.P(`tmpVal := *`, rhs)
p.P(lhs, ` = &tmpVal`)
} else {
p.cloneFieldSingular(lhs, rhs, fieldKind, msg)
}
p.P(`}`)
}
func (p *clone) generateCloneMethodsForMessage(proto3 bool, message *protogen.Message) {
ccTypeName := message.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, cloneName, `() *`, ccTypeName, ` {`)
p.body(!proto3, ccTypeName, message, true)
p.P(`}`)
p.P()
if !p.Wrapper() {
p.P(`func (m *`, ccTypeName, `) `, cloneMessageName, `() `, protoPkg.Ident("Message"), ` {`)
p.P(`return m.`, cloneName, `()`)
p.P(`}`)
p.P()
}
}
// body generates the code for the actual cloning logic of a structure containing the given fields.
// In practice, those can be the fields of a message.
// The object to be cloned is assumed to be called "m".
func (p *clone) body(allFieldsNullable bool, ccTypeName string, message *protogen.Message, cloneUnknownFields bool) {
// The method body for a message or a oneof wrapper always starts with a nil check.
p.P(`if m == nil {`)
// We use an explicitly typed nil to avoid returning the nil interface in the oneof wrapper
// case.
p.P(`return (*`, ccTypeName, `)(nil)`)
p.P(`}`)
fields := message.Fields
// Make a first pass over the fields, in which we initialize all non-reference fields via direct
// struct literal initialization, and extract all other (reference) fields for a second pass.
// Do not require qualified name because CloneVT generates in same file with definition.
p.Alloc("r", message, false)
var refFields []*protogen.Field
oneofFields := make(map[string]struct{}, len(fields))
for _, field := range fields {
if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
// Use the first field in a oneof as the representative for that oneof, disregard
// the other fields in that oneof.
if _, ok := oneofFields[field.Oneof.GoName]; !ok {
refFields = append(refFields, field)
oneofFields[field.Oneof.GoName] = struct{}{}
}
continue
}
if !isReference(allFieldsNullable, field) {
p.P(`r.`, field.GoName, ` = m.`, field.GoName)
continue
}
// Shortcut: for types where we know that an optimized clone method exists, we can call it directly as it is
// nil-safe.
if field.Desc.Cardinality() != protoreflect.Repeated {
switch {
case p.IsWellKnownType(field.Message):
p.P(`r.`, field.GoName, ` = (*`, field.Message.GoIdent, `)((*`, p.WellKnownTypeMap(field.Message), `)(m.`, field.GoName, `).`, cloneName, `())`)
continue
case p.IsLocalMessage(field.Message):
p.P(`r.`, field.GoName, ` = m.`, field.GoName, `.`, cloneName, `()`)
continue
}
}
refFields = append(refFields, field)
}
// Generate explicit assignment statements for all reference fields.
for _, field := range refFields {
p.cloneField("r", "m", allFieldsNullable, field)
}
if cloneUnknownFields && !p.Wrapper() {
// Clone unknown fields, if any
p.P(`if len(m.unknownFields) > 0 {`)
p.P(`r.unknownFields = make([]byte, len(m.unknownFields))`)
p.P(`copy(r.unknownFields, m.unknownFields)`)
p.P(`}`)
}
p.P(`return r`)
}
func (p *clone) bodyForOneOf(ccTypeName string, field *protogen.Field) {
// The method body for a message or a oneof wrapper always starts with a nil check.
p.P(`if m == nil {`)
// We use an explicitly typed nil to avoid returning the nil interface in the oneof wrapper
// case.
p.P(`return (*`, ccTypeName, `)(nil)`)
p.P(`}`)
p.P("r", " := new(", ccTypeName, `)`)
if !isReference(false, field) {
p.P(`r.`, field.GoName, ` = m.`, field.GoName)
p.P(`return r`)
return
}
// Shortcut: for types where we know that an optimized clone method exists, we can call it directly as it is
// nil-safe.
if field.Desc.Cardinality() != protoreflect.Repeated && field.Message != nil {
switch {
case p.IsWellKnownType(field.Message):
p.P(`r.`, field.GoName, ` = (*`, field.Message.GoIdent, `)((*`, p.WellKnownTypeMap(field.Message), `)(m.`, field.GoName, `).`, cloneName, `())`)
p.P(`return r`)
return
case p.IsLocalMessage(field.Message):
p.P(`r.`, field.GoName, ` = m.`, field.GoName, `.`, cloneName, `()`)
p.P(`return r`)
return
}
}
// Generate explicit assignment statements for reference field.
p.cloneField("r", "m", false, field)
p.P(`return r`)
}
// generateCloneMethodsForOneof generates the clone method for the oneof wrapper type of a
// field in a oneof.
func (p *clone) generateCloneMethodsForOneof(message *protogen.Message, field *protogen.Field) {
ccTypeName := field.GoIdent.GoName
ccInterfaceName := "is" + field.Oneof.GoIdent.GoName
if p.IsWellKnownType(message) {
p.P(`func (m *`, ccTypeName, `) `, cloneName, `() *`, ccTypeName, ` {`)
} else {
p.P(`func (m *`, ccTypeName, `) `, cloneName, `() `, ccInterfaceName, ` {`)
}
// Create a "fake" field for the single oneof member, pretending it is not a oneof field.
fieldInOneof := *field
fieldInOneof.Oneof = nil
// If we have a scalar field in a oneof, that field is never nullable, even when using proto2
p.bodyForOneOf(ccTypeName, &fieldInOneof)
p.P(`}`)
p.P()
}
func (p *clone) processMessageOneofs(message *protogen.Message) {
for _, field := range message.Fields {
if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() {
continue
}
p.generateCloneMethodsForOneof(message, field)
}
}
func (p *clone) processMessage(proto3 bool, message *protogen.Message) {
for _, nested := range message.Messages {
p.processMessage(proto3, nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
p.generateCloneMethodsForMessage(proto3, message)
p.processMessageOneofs(message)
}
// isReference checks whether the Go equivalent of the given field is of reference type, i.e., can be nil.
func isReference(allFieldsNullable bool, field *protogen.Field) bool {
if allFieldsNullable || field.Oneof != nil || field.Message != nil || field.Desc.Cardinality() == protoreflect.Repeated || field.Desc.Kind() == protoreflect.BytesKind {
return true
}
if !isScalar(field.Desc.Kind()) {
panic("unexpected non-reference, non-scalar field")
}
return false
}
func isScalar(kind protoreflect.Kind) bool {
switch kind {
case
protoreflect.BoolKind,
protoreflect.StringKind,
protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind,
protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind,
protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind,
protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind,
protoreflect.EnumKind:
return true
}
return false
}

View File

@ -0,0 +1,308 @@
// Copyright (c) 2022 PlanetScale Inc. All rights reserved.
package equal
import (
"fmt"
"sort"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
func init() {
generator.RegisterFeature("equal", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &equal{GeneratedFile: gen}
})
}
var (
protoPkg = protogen.GoImportPath("google.golang.org/protobuf/proto")
)
type equal struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*equal)(nil)
func (p *equal) Name() string { return "equal" }
func (p *equal) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.message(proto3, message)
}
return p.once
}
const equalName = "EqualVT"
const equalMessageName = "EqualMessageVT"
func (p *equal) message(proto3 bool, message *protogen.Message) {
for _, nested := range message.Messages {
p.message(proto3, nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
ccTypeName := message.GoIdent.GoName
p.P(`func (this *`, ccTypeName, `) `, equalName, `(that *`, ccTypeName, `) bool {`)
p.P(`if this == that {`)
p.P(` return true`)
p.P(`} else if this == nil || that == nil {`)
p.P(` return false`)
p.P(`}`)
sort.Slice(message.Fields, func(i, j int) bool {
return message.Fields[i].Desc.Number() < message.Fields[j].Desc.Number()
})
{
oneofs := make(map[string]struct{}, len(message.Fields))
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
continue
}
fieldname := field.Oneof.GoName
if _, ok := oneofs[fieldname]; ok {
continue
}
oneofs[fieldname] = struct{}{}
p.P(`if this.`, fieldname, ` == nil && that.`, fieldname, ` != nil {`)
p.P(` return false`)
p.P(`} else if this.`, fieldname, ` != nil {`)
p.P(` if that.`, fieldname, ` == nil {`)
p.P(` return false`)
p.P(` }`)
ccInterfaceName := fmt.Sprintf("is%s", field.Oneof.GoIdent.GoName)
if p.IsWellKnownType(message) {
p.P(`switch c := this.`, fieldname, `.(type) {`)
for _, f := range field.Oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
p.P(`if !(*`, p.WellKnownFieldMap(f), `)(c).`, equalName, `(that.`, fieldname, `) {`)
p.P(`return false`)
p.P(`}`)
}
p.P(`}`)
} else {
p.P(`if !this.`, fieldname, `.(interface{ `, equalName, `(`, ccInterfaceName, `) bool }).`, equalName, `(that.`, fieldname, `) {`)
p.P(`return false`)
p.P(`}`)
}
p.P(`}`)
}
}
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
nullable := field.Message != nil || (field.Oneof != nil && field.Oneof.Desc.IsSynthetic()) || (!proto3 && !oneof)
if !oneof {
p.field(field, nullable)
}
}
if p.Wrapper() {
p.P(`return true`)
} else {
p.P(`return string(this.unknownFields) == string(that.unknownFields)`)
}
p.P(`}`)
p.P()
if !p.Wrapper() {
p.P(`func (this *`, ccTypeName, `) `, equalMessageName, `(thatMsg `, protoPkg.Ident("Message"), `) bool {`)
p.P(`that, ok := thatMsg.(*`, ccTypeName, `)`)
p.P(`if !ok {`)
p.P(`return false`)
p.P(`}`)
p.P(`return this.`, equalName, `(that)`)
p.P(`}`)
}
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
continue
}
p.oneof(field)
}
}
func (p *equal) oneof(field *protogen.Field) {
ccTypeName := field.GoIdent.GoName
ccInterfaceName := fmt.Sprintf("is%s", field.Oneof.GoIdent.GoName)
fieldname := field.GoName
if p.IsWellKnownType(field.Parent) {
p.P(`func (this *`, ccTypeName, `) `, equalName, `(thatIface any) bool {`)
} else {
p.P(`func (this *`, ccTypeName, `) `, equalName, `(thatIface `, ccInterfaceName, `) bool {`)
}
p.P(`that, ok := thatIface.(*`, ccTypeName, `)`)
p.P(`if !ok {`)
if p.IsWellKnownType(field.Parent) {
p.P(`if ot, ok := thatIface.(*`, field.GoIdent, `); ok {`)
p.P(`that = (*`, ccTypeName, `)(ot)`)
p.P("} else {")
p.P("return false")
p.P("}")
} else {
p.P(`return false`)
}
p.P(`}`)
p.P(`if this == that {`)
p.P(`return true`)
p.P(`}`)
p.P(`if this == nil && that != nil || this != nil && that == nil {`)
p.P(`return false`)
p.P(`}`)
lhs := fmt.Sprintf("this.%s", fieldname)
rhs := fmt.Sprintf("that.%s", fieldname)
kind := field.Desc.Kind()
switch {
case isScalar(kind):
p.compareScalar(lhs, rhs, false)
case kind == protoreflect.BytesKind:
p.compareBytes(lhs, rhs, false)
case kind == protoreflect.MessageKind || kind == protoreflect.GroupKind:
p.compareCall(lhs, rhs, field.Message, false)
default:
panic("not implemented")
}
p.P(`return true`)
p.P(`}`)
p.P()
}
func (p *equal) field(field *protogen.Field, nullable bool) {
fieldname := field.GoName
repeated := field.Desc.Cardinality() == protoreflect.Repeated
lhs := fmt.Sprintf("this.%s", fieldname)
rhs := fmt.Sprintf("that.%s", fieldname)
if repeated {
p.P(`if len(`, lhs, `) != len(`, rhs, `) {`)
p.P(` return false`)
p.P(`}`)
p.P(`for i, vx := range `, lhs, ` {`)
if field.Desc.IsMap() {
p.P(`vy, ok := `, rhs, `[i]`)
p.P(`if !ok {`)
p.P(`return false`)
p.P(`}`)
field = field.Message.Fields[1]
} else {
p.P(`vy := `, rhs, `[i]`)
}
lhs, rhs = "vx", "vy"
nullable = false
}
kind := field.Desc.Kind()
switch {
case isScalar(kind):
p.compareScalar(lhs, rhs, nullable)
case kind == protoreflect.BytesKind:
p.compareBytes(lhs, rhs, nullable)
case kind == protoreflect.MessageKind || kind == protoreflect.GroupKind:
p.compareCall(lhs, rhs, field.Message, nullable)
default:
panic("not implemented")
}
if repeated {
// close for loop
p.P(`}`)
}
}
func (p *equal) compareScalar(lhs, rhs string, nullable bool) {
if nullable {
p.P(`if p, q := `, lhs, `, `, rhs, `; (p == nil && q != nil) || (p != nil && (q == nil || *p != *q)) {`)
} else {
p.P(`if `, lhs, ` != `, rhs, ` {`)
}
p.P(` return false`)
p.P(`}`)
}
func (p *equal) compareBytes(lhs, rhs string, nullable bool) {
if nullable {
p.P(`if p, q := `, lhs, `, `, rhs, `; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) {`)
} else {
// Inlined call to bytes.Equal()
p.P(`if string(`, lhs, `) != string(`, rhs, `) {`)
}
p.P(` return false`)
p.P(`}`)
}
func (p *equal) compareCall(lhs, rhs string, msg *protogen.Message, nullable bool) {
if !nullable {
// The p != q check is mostly intended to catch the lhs = nil, rhs = nil case in which we would pointlessly
// allocate not just one but two empty values. However, it also provides us with an extra scope to establish
// our p and q variables.
p.P(`if p, q := `, lhs, `, `, rhs, `; p != q {`)
defer p.P(`}`)
p.P(`if p == nil {`)
p.P(`p = &`, p.QualifiedGoIdent(msg.GoIdent), `{}`)
p.P(`}`)
p.P(`if q == nil {`)
p.P(`q = &`, p.QualifiedGoIdent(msg.GoIdent), `{}`)
p.P(`}`)
lhs, rhs = "p", "q"
}
switch {
case p.IsWellKnownType(msg):
wkt := p.WellKnownTypeMap(msg)
p.P(`if !(*`, wkt, `)(`, lhs, `).`, equalName, `((*`, wkt, `)(`, rhs, `)) {`)
p.P(` return false`)
p.P(`}`)
case p.IsLocalMessage(msg):
p.P(`if !`, lhs, `.`, equalName, `(`, rhs, `) {`)
p.P(` return false`)
p.P(`}`)
default:
p.P(`if equal, ok := interface{}(`, lhs, `).(interface { `, equalName, `(*`, p.QualifiedGoIdent(msg.GoIdent), `) bool }); ok {`)
p.P(` if !equal.`, equalName, `(`, rhs, `) {`)
p.P(` return false`)
p.P(` }`)
p.P(`} else if !`, p.Ident("google.golang.org/protobuf/proto", "Equal"), `(`, lhs, `, `, rhs, `) {`)
p.P(` return false`)
p.P(`}`)
}
}
func isScalar(kind protoreflect.Kind) bool {
switch kind {
case
protoreflect.BoolKind,
protoreflect.StringKind,
protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind,
protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind,
protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind,
protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind,
protoreflect.EnumKind:
return true
}
return false
}

View File

@ -0,0 +1,422 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"strconv"
"strings"
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"
)
const (
contextPackage = protogen.GoImportPath("context")
grpcPackage = protogen.GoImportPath("google.golang.org/grpc")
codesPackage = protogen.GoImportPath("google.golang.org/grpc/codes")
statusPackage = protogen.GoImportPath("google.golang.org/grpc/status")
)
// generateFileContent generates the gRPC service definitions, excluding the package statement.
func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile) {
if len(file.Services) == 0 {
return
}
g.P("// This is a compile-time assertion to ensure that this generated file")
g.P("// is compatible with the grpc package it is being compiled against.")
g.P("// Requires gRPC-Go v1.32.0 or later.")
g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion7")) // When changing, update version number above.
g.P()
for _, service := range file.Services {
genService(gen, file, g, service)
}
}
func genService(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile, service *protogen.Service) {
clientName := service.GoName + "Client"
g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
g.P("//")
g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.")
// Client interface.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P("//")
g.P(deprecationComment)
}
g.Annotate(clientName, service.Location)
g.P("type ", clientName, " interface {")
for _, method := range service.Methods {
g.Annotate(clientName+"."+method.GoName, method.Location)
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P(method.Comments.Leading,
clientSignature(g, method))
}
g.P("}")
g.P()
// Client structure.
g.P("type ", unexport(clientName), " struct {")
g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
g.P("}")
g.P()
// NewClient factory.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
g.P("return &", unexport(clientName), "{cc}")
g.P("}")
g.P()
var methodIndex, streamIndex int
// Client method implementations.
for _, method := range service.Methods {
if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
// Unary RPC method
genClientMethod(gen, file, g, method, methodIndex)
methodIndex++
} else {
// Streaming RPC method
genClientMethod(gen, file, g, method, streamIndex)
streamIndex++
}
}
mustOrShould := "must"
if !*requireUnimplemented {
mustOrShould = "should"
}
// Server interface.
serverType := service.GoName + "Server"
g.P("// ", serverType, " is the server API for ", service.GoName, " service.")
g.P("// All implementations ", mustOrShould, " embed Unimplemented", serverType)
g.P("// for forward compatibility")
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P("//")
g.P(deprecationComment)
}
g.Annotate(serverType, service.Location)
g.P("type ", serverType, " interface {")
for _, method := range service.Methods {
g.Annotate(serverType+"."+method.GoName, method.Location)
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P(method.Comments.Leading,
serverSignature(g, method))
}
if *requireUnimplemented {
g.P("mustEmbedUnimplemented", serverType, "()")
}
g.P("}")
g.P()
// Server Unimplemented struct for forward compatibility.
g.P("// Unimplemented", serverType, " ", mustOrShould, " be embedded to have forward compatible implementations.")
g.P("type Unimplemented", serverType, " struct {")
g.P("}")
g.P()
for _, method := range service.Methods {
nilArg := ""
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
nilArg = "nil,"
}
g.P("func (Unimplemented", serverType, ") ", serverSignature(g, method), "{")
g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`)
g.P("}")
}
if *requireUnimplemented {
g.P("func (Unimplemented", serverType, ") mustEmbedUnimplemented", serverType, "() {}")
}
g.P()
// Unsafe Server interface to opt-out of forward compatibility.
g.P("// Unsafe", serverType, " may be embedded to opt out of forward compatibility for this service.")
g.P("// Use of this interface is not recommended, as added methods to ", serverType, " will")
g.P("// result in compilation errors.")
g.P("type Unsafe", serverType, " interface {")
g.P("mustEmbedUnimplemented", serverType, "()")
g.P("}")
// Server registration.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P(deprecationComment)
}
serviceDescVar := service.GoName + "_ServiceDesc"
g.P("func Register", service.GoName, "Server(s ", grpcPackage.Ident("ServiceRegistrar"), ", srv ", serverType, ") {")
g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
g.P("}")
g.P()
// Server handler implementations.
var handlerNames []string
for _, method := range service.Methods {
hname := genServerMethod(gen, file, g, method)
handlerNames = append(handlerNames, hname)
}
// Service descriptor.
g.P("// ", serviceDescVar, " is the ", grpcPackage.Ident("ServiceDesc"), " for ", service.GoName, " service.")
g.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",")
g.P("// and not to be introspected or modified (even as a copy)")
g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
g.P("HandlerType: (*", serverType, ")(nil),")
g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
for i, method := range service.Methods {
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
g.P("},")
}
g.P("},")
g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
for i, method := range service.Methods {
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
if method.Desc.IsStreamingServer() {
g.P("ServerStreams: true,")
}
if method.Desc.IsStreamingClient() {
g.P("ClientStreams: true,")
}
g.P("},")
}
g.P("},")
g.P("Metadata: \"", file.Desc.Path(), "\",")
g.P("}")
g.P()
}
func clientSignature(g *generator.GeneratedFile, method *protogen.Method) string {
s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
if !method.Desc.IsStreamingClient() {
s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent)
}
s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") ("
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
s += "*" + g.QualifiedGoIdent(method.Output.GoIdent)
} else {
s += method.Parent.GoName + "_" + method.GoName + "Client"
}
s += ", error)"
return s
}
func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile, method *protogen.Method, index int) {
service := method.Parent
sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{")
if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
// g.P("out := new(", method.Output.GoIdent, ")")
g.Alloc("out", method.Output, true)
g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`)
g.P("if err != nil { return nil, err }")
g.P("return out, nil")
g.P("}")
g.P()
return
}
streamType := unexport(service.GoName) + method.GoName + "Client"
serviceDescVar := service.GoName + "_ServiceDesc"
g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`)
g.P("if err != nil { return nil, err }")
g.P("x := &", streamType, "{stream}")
if !method.Desc.IsStreamingClient() {
g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
}
g.P("return x, nil")
g.P("}")
g.P()
genSend := method.Desc.IsStreamingClient()
genRecv := method.Desc.IsStreamingServer()
genCloseAndRecv := !method.Desc.IsStreamingServer()
// Stream auxiliary types and methods.
g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
if genSend {
g.P("Send(*", method.Input.GoIdent, ") error")
}
if genRecv {
g.P("Recv() (*", method.Output.GoIdent, ", error)")
}
if genCloseAndRecv {
g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
}
g.P(grpcPackage.Ident("ClientStream"))
g.P("}")
g.P()
g.P("type ", streamType, " struct {")
g.P(grpcPackage.Ident("ClientStream"))
g.P("}")
g.P()
if genSend {
g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
g.P("return x.ClientStream.SendMsg(m)")
g.P("}")
g.P()
}
if genRecv {
g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
// g.P("m := new(", method.Output.GoIdent, ")")
g.Alloc("m", method.Output, true)
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
g.P("}")
g.P()
}
if genCloseAndRecv {
g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
// g.P("m := new(", method.Output.GoIdent, ")")
g.Alloc("m", method.Output, true)
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
g.P("}")
g.P()
}
}
func serverSignature(g *generator.GeneratedFile, method *protogen.Method) string {
var reqArgs []string
ret := "error"
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context")))
ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)"
}
if !method.Desc.IsStreamingClient() {
reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent))
}
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
}
return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
}
func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *generator.GeneratedFile, method *protogen.Method) string {
service := method.Parent
hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
// g.P("in := new(", method.Input.GoIdent, ")")
g.Alloc("in", method.Input, true)
g.P("if err := dec(in); err != nil { return nil, err }")
g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
g.P("Server: srv,")
g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())), ",")
g.P("}")
g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))")
g.P("}")
g.P("return interceptor(ctx, in, info, handler)")
g.P("}")
g.P()
return hname
}
streamType := unexport(service.GoName) + method.GoName + "Server"
g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
if !method.Desc.IsStreamingClient() {
// g.P("m := new(", method.Input.GoIdent, ")")
g.Alloc("m", method.Input, true)
g.P("if err := stream.RecvMsg(m); err != nil { return err }")
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
} else {
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
}
g.P("}")
g.P()
genSend := method.Desc.IsStreamingServer()
genSendAndClose := !method.Desc.IsStreamingServer()
genRecv := method.Desc.IsStreamingClient()
// Stream auxiliary types and methods.
g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
if genSend {
g.P("Send(*", method.Output.GoIdent, ") error")
}
if genSendAndClose {
g.P("SendAndClose(*", method.Output.GoIdent, ") error")
}
if genRecv {
g.P("Recv() (*", method.Input.GoIdent, ", error)")
}
g.P(grpcPackage.Ident("ServerStream"))
g.P("}")
g.P()
g.P("type ", streamType, " struct {")
g.P(grpcPackage.Ident("ServerStream"))
g.P("}")
g.P()
if genSend {
g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
g.P("return x.ServerStream.SendMsg(m)")
g.P("}")
g.P()
}
if genSendAndClose {
g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
g.P("return x.ServerStream.SendMsg(m)")
g.P("}")
g.P()
}
if genRecv {
g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
// g.P("m := new(", method.Input.GoIdent, ")")
g.Alloc("m", method.Input, true)
g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
g.P("}")
g.P()
}
return hname
}
const deprecationComment = "// Deprecated: Do not use."
func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] }

View File

@ -0,0 +1,34 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package grpc
import (
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
)
const version = "1.1.0-vtproto"
var requireUnimplementedAlways = true
var requireUnimplemented = &requireUnimplementedAlways
func init() {
generator.RegisterFeature("grpc", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &grpc{gen}
})
}
type grpc struct {
*generator.GeneratedFile
}
func (g *grpc) GenerateFile(file *protogen.File) bool {
if len(file.Services) == 0 {
return false
}
generateFileContent(nil, file, g.GeneratedFile)
return true
}

View File

@ -0,0 +1,748 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package marshal
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
)
func init() {
generator.RegisterFeature("marshal", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &marshal{GeneratedFile: gen, Stable: false, strict: false}
})
generator.RegisterFeature("marshal_strict", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &marshal{GeneratedFile: gen, Stable: false, strict: true}
})
}
type counter int
func (cnt *counter) Next() string {
*cnt++
return cnt.Current()
}
func (cnt *counter) Current() string {
return strconv.Itoa(int(*cnt))
}
type marshal struct {
*generator.GeneratedFile
Stable, once, strict bool
}
var _ generator.FeatureGenerator = (*marshal)(nil)
func (p *marshal) GenerateFile(file *protogen.File) bool {
for _, message := range file.Messages {
p.message(message)
}
return p.once
}
func (p *marshal) encodeFixed64(varName ...string) {
p.P(`i -= 8`)
p.P(p.Ident("encoding/binary", "LittleEndian"), `.PutUint64(dAtA[i:], uint64(`, strings.Join(varName, ""), `))`)
}
func (p *marshal) encodeFixed32(varName ...string) {
p.P(`i -= 4`)
p.P(p.Ident("encoding/binary", "LittleEndian"), `.PutUint32(dAtA[i:], uint32(`, strings.Join(varName, ""), `))`)
}
func (p *marshal) encodeVarint(varName ...string) {
p.P(`i = `, p.Helper("EncodeVarint"), `(dAtA, i, uint64(`, strings.Join(varName, ""), `))`)
}
func (p *marshal) encodeKey(fieldNumber protoreflect.FieldNumber, wireType protowire.Type) {
x := uint32(fieldNumber)<<3 | uint32(wireType)
i := 0
keybuf := make([]byte, 0)
for i = 0; x > 127; i++ {
keybuf = append(keybuf, 0x80|uint8(x&0x7F))
x >>= 7
}
keybuf = append(keybuf, uint8(x))
for i = len(keybuf) - 1; i >= 0; i-- {
p.P(`i--`)
p.P(`dAtA[i] = `, fmt.Sprintf("%#v", keybuf[i]))
}
}
func (p *marshal) mapField(kvField *protogen.Field, varName string) {
switch kvField.Desc.Kind() {
case protoreflect.DoubleKind:
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(`, varName, `))`)
case protoreflect.FloatKind:
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(`, varName, `))`)
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.EnumKind:
p.encodeVarint(varName)
case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
p.encodeFixed64(varName)
case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
p.encodeFixed32(varName)
case protoreflect.BoolKind:
p.P(`i--`)
p.P(`if `, varName, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
case protoreflect.StringKind, protoreflect.BytesKind:
p.P(`i -= len(`, varName, `)`)
p.P(`copy(dAtA[i:], `, varName, `)`)
p.encodeVarint(`len(`, varName, `)`)
case protoreflect.Sint32Kind:
p.encodeVarint(`(uint32(`, varName, `) << 1) ^ uint32((`, varName, ` >> 31))`)
case protoreflect.Sint64Kind:
p.encodeVarint(`(uint64(`, varName, `) << 1) ^ uint64((`, varName, ` >> 63))`)
case protoreflect.MessageKind:
p.marshalBackward(varName, true, kvField.Message)
}
}
func (p *marshal) field(oneof bool, numGen *counter, field *protogen.Field) {
fieldname := field.GoName
nullable := field.Message != nil || (!oneof && field.Desc.HasPresence())
repeated := field.Desc.Cardinality() == protoreflect.Repeated
if repeated {
p.P(`if len(m.`, fieldname, `) > 0 {`)
} else if nullable {
if field.Desc.Cardinality() == protoreflect.Required {
p.P(`if m.`, fieldname, ` == nil {`)
p.P(`return 0, `, p.Ident("fmt", "Errorf"), `("proto: required field `, field.Desc.Name(), ` not set")`)
p.P(`} else {`)
} else {
p.P(`if m.`, fieldname, ` != nil {`)
}
}
packed := field.Desc.IsPacked()
wireType := generator.ProtoWireType(field.Desc.Kind())
fieldNumber := field.Desc.Number()
if packed {
wireType = protowire.BytesType
}
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float64bits"), `(float64(`, val, `))`)
p.encodeFixed64("f", numGen.Current())
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 8`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float64bits"), `(float64(`, val, `))`)
p.encodeFixed64("f", numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(*m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(m.`, fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed64(p.Ident("math", "Float64bits"), `(float64(m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.FloatKind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float32bits"), `(float32(`, val, `))`)
p.encodeFixed32("f" + numGen.Current())
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 4`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`f`, numGen.Next(), ` := `, p.Ident("math", "Float32bits"), `(float32(`, val, `))`)
p.encodeFixed32("f" + numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(*m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed32(p.Ident("math", "Float32bits"), `(float32(m.`+fieldname, `))`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.EnumKind:
if packed {
jvar := "j" + numGen.Next()
total := "pksize" + numGen.Next()
p.P(`var `, total, ` int`)
p.P(`for _, num := range m.`, fieldname, ` {`)
p.P(total, ` += `, p.Helper("SizeOfVarint"), `(uint64(num))`)
p.P(`}`)
p.P(`i -= `, total)
p.P(jvar, `:= i`)
switch field.Desc.Kind() {
case protoreflect.Int64Kind, protoreflect.Int32Kind, protoreflect.EnumKind:
p.P(`for _, num1 := range m.`, fieldname, ` {`)
p.P(`num := uint64(num1)`)
default:
p.P(`for _, num := range m.`, fieldname, ` {`)
}
p.P(`for num >= 1<<7 {`)
p.P(`dAtA[`, jvar, `] = uint8(uint64(num)&0x7f|0x80)`)
p.P(`num >>= 7`)
p.P(jvar, `++`)
p.P(`}`)
p.P(`dAtA[`, jvar, `] = uint8(num)`)
p.P(jvar, `++`)
p.P(`}`)
p.encodeVarint(total)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.encodeVarint(val)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeVarint(`*m.`, fieldname)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeVarint(`m.`, fieldname)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeVarint(`m.`, fieldname)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed64(val)
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 8`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed64(val)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed64("*m.", fieldname)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed64("m.", fieldname)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed64("m.", fieldname)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed32(val)
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `) * 4`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.encodeFixed32(val)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeFixed32("*m." + fieldname)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeFixed32("m." + fieldname)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeFixed32("m." + fieldname)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.BoolKind:
if packed {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i--`)
p.P(`if `, val, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.P(`}`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i--`)
p.P(`if `, val, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.P(`i--`)
p.P(`if *m.`, fieldname, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` {`)
p.P(`i--`)
p.P(`if m.`, fieldname, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.P(`i--`)
p.P(`if m.`, fieldname, ` {`)
p.P(`dAtA[i] = 1`)
p.P(`} else {`)
p.P(`dAtA[i] = 0`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.StringKind:
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i -= len(`, val, `)`)
p.P(`copy(dAtA[i:], `, val, `)`)
p.encodeVarint(`len(`, val, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.P(`i -= len(*m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], *m.`, fieldname, `)`)
p.encodeVarint(`len(*m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if len(m.`, fieldname, `) > 0 {`)
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.GroupKind:
p.encodeKey(fieldNumber, protowire.EndGroupType)
p.marshalBackward(`m.`+fieldname, false, field.Message)
p.encodeKey(fieldNumber, protowire.StartGroupType)
case protoreflect.MessageKind:
if field.Desc.IsMap() {
goTypK, _ := p.FieldGoType(field.Message.Fields[0])
keyKind := field.Message.Fields[0].Desc.Kind()
valKind := field.Message.Fields[1].Desc.Kind()
var val string
if p.Stable && keyKind != protoreflect.BoolKind {
keysName := `keysFor` + fieldname
p.P(keysName, ` := make([]`, goTypK, `, 0, len(m.`, fieldname, `))`)
p.P(`for k := range m.`, fieldname, ` {`)
p.P(keysName, ` = append(`, keysName, `, `, goTypK, `(k))`)
p.P(`}`)
p.P(p.Ident("sort", "Slice"), `(`, keysName, `, func(i, j int) bool {`)
p.P(`return `, keysName, `[i] < `, keysName, `[j]`)
p.P(`})`)
val = p.reverseListRange(keysName)
} else {
p.P(`for k := range m.`, fieldname, ` {`)
val = "k"
}
if p.Stable {
p.P(`v := m.`, fieldname, `[`, goTypK, `(`, val, `)]`)
} else {
p.P(`v := m.`, fieldname, `[`, val, `]`)
}
p.P(`baseI := i`)
accessor := `v`
p.mapField(field.Message.Fields[1], accessor)
p.encodeKey(2, generator.ProtoWireType(valKind))
p.mapField(field.Message.Fields[0], val)
p.encodeKey(1, generator.ProtoWireType(keyKind))
p.encodeVarint(`baseI - i`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.marshalBackward(val, true, field.Message)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.marshalBackward(`m.`+fieldname, true, field.Message)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.BytesKind:
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i -= len(`, val, `)`)
p.P(`copy(dAtA[i:], `, val, `)`)
p.encodeVarint(`len(`, val, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if !oneof && !field.Desc.HasPresence() {
p.P(`if len(m.`, fieldname, `) > 0 {`)
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.encodeVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Sint32Kind:
if packed {
jvar := "j" + numGen.Next()
total := "pksize" + numGen.Next()
p.P(`var `, total, ` int`)
p.P(`for _, num := range m.`, fieldname, ` {`)
p.P(total, ` += `, p.Helper("SizeOfZigzag"), `(uint64(num))`)
p.P(`}`)
p.P(`i -= `, total)
p.P(jvar, `:= i`)
p.P(`for _, num := range m.`, fieldname, ` {`)
xvar := "x" + numGen.Next()
p.P(xvar, ` := (uint32(num) << 1) ^ uint32((num >> 31))`)
p.P(`for `, xvar, ` >= 1<<7 {`)
p.P(`dAtA[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`)
p.P(jvar, `++`)
p.P(xvar, ` >>= 7`)
p.P(`}`)
p.P(`dAtA[`, jvar, `] = uint8(`, xvar, `)`)
p.P(jvar, `++`)
p.P(`}`)
p.encodeVarint(total)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`x`, numGen.Next(), ` := (uint32(`, val, `) << 1) ^ uint32((`, val, ` >> 31))`)
p.encodeVarint(`x`, numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeVarint(`(uint32(*m.`, fieldname, `) << 1) ^ uint32((*m.`, fieldname, ` >> 31))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`)
p.encodeKey(fieldNumber, wireType)
}
case protoreflect.Sint64Kind:
if packed {
jvar := "j" + numGen.Next()
total := "pksize" + numGen.Next()
p.P(`var `, total, ` int`)
p.P(`for _, num := range m.`, fieldname, ` {`)
p.P(total, ` += `, p.Helper("SizeOfZigzag"), `(uint64(num))`)
p.P(`}`)
p.P(`i -= `, total)
p.P(jvar, `:= i`)
p.P(`for _, num := range m.`, fieldname, ` {`)
xvar := "x" + numGen.Next()
p.P(xvar, ` := (uint64(num) << 1) ^ uint64((num >> 63))`)
p.P(`for `, xvar, ` >= 1<<7 {`)
p.P(`dAtA[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`)
p.P(jvar, `++`)
p.P(xvar, ` >>= 7`)
p.P(`}`)
p.P(`dAtA[`, jvar, `] = uint8(`, xvar, `)`)
p.P(jvar, `++`)
p.P(`}`)
p.encodeVarint(total)
p.encodeKey(fieldNumber, wireType)
} else if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`x`, numGen.Next(), ` := (uint64(`, val, `) << 1) ^ uint64((`, val, ` >> 63))`)
p.encodeVarint("x" + numGen.Current())
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if nullable {
p.encodeVarint(`(uint64(*m.`, fieldname, `) << 1) ^ uint64((*m.`, fieldname, ` >> 63))`)
p.encodeKey(fieldNumber, wireType)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.encodeVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else {
p.encodeVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`)
p.encodeKey(fieldNumber, wireType)
}
default:
panic("not implemented")
}
// Empty protobufs should emit a message or compatibility with Golang protobuf;
// See https://github.com/planetscale/vtprotobuf/issues/61
if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() {
p.P("} else {")
p.P("i = protohelpers.EncodeVarint(dAtA, i, 0)")
p.encodeKey(fieldNumber, wireType)
p.P("}")
} else if repeated || nullable {
p.P(`}`)
}
}
func (p *marshal) methodMarshalToSizedBuffer() string {
switch {
case p.strict:
return "MarshalToSizedBufferVTStrict"
default:
return "MarshalToSizedBufferVT"
}
}
func (p *marshal) methodMarshalTo() string {
switch {
case p.strict:
return "MarshalToVTStrict"
default:
return "MarshalToVT"
}
}
func (p *marshal) methodMarshal() string {
switch {
case p.strict:
return "MarshalVTStrict"
default:
return "MarshalVT"
}
}
func (p *marshal) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
var numGen counter
ccTypeName := message.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshal(), `() (dAtA []byte, err error) {`)
p.P(`if m == nil {`)
p.P(`return nil, nil`)
p.P(`}`)
p.P(`size := m.SizeVT()`)
p.P(`dAtA = make([]byte, size)`)
p.P(`n, err := m.`, p.methodMarshalToSizedBuffer(), `(dAtA[:size])`)
p.P(`if err != nil {`)
p.P(`return nil, err`)
p.P(`}`)
p.P(`return dAtA[:n], nil`)
p.P(`}`)
p.P(``)
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalTo(), `(dAtA []byte) (int, error) {`)
p.P(`size := m.SizeVT()`)
p.P(`return m.`, p.methodMarshalToSizedBuffer(), `(dAtA[:size])`)
p.P(`}`)
p.P(``)
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalToSizedBuffer(), `(dAtA []byte) (int, error) {`)
p.P(`if m == nil {`)
p.P(`return 0, nil`)
p.P(`}`)
p.P(`i := len(dAtA)`)
p.P(`_ = i`)
p.P(`var l int`)
p.P(`_ = l`)
if !p.Wrapper() {
p.P(`if m.unknownFields != nil {`)
p.P(`i -= len(m.unknownFields)`)
p.P(`copy(dAtA[i:], m.unknownFields)`)
p.P(`}`)
}
sort.Slice(message.Fields, func(i, j int) bool {
return message.Fields[i].Desc.Number() < message.Fields[j].Desc.Number()
})
marshalForwardOneOf := func(varname ...any) {
l := []any{`size, err := `}
l = append(l, varname...)
l = append(l, `.`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.P(l...)
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.P(`i -= size`)
}
if p.strict {
for i := len(message.Fields) - 1; i >= 0; i-- {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(false, &numGen, field)
} else {
if p.IsWellKnownType(message) {
p.P(`if m, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent, `); ok {`)
p.P(`msg := ((*`, p.WellKnownFieldMap(field), `)(m))`)
} else {
p.P(`if msg, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent.GoName, `); ok {`)
}
marshalForwardOneOf("msg")
p.P(`}`)
}
}
} else {
// To match the wire format of proto.Marshal, oneofs have to be marshaled
// before fields. See https://github.com/planetscale/vtprotobuf/pull/22
oneofs := make(map[string]struct{}, len(message.Fields))
for i := len(message.Fields) - 1; i >= 0; i-- {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if oneof {
fieldname := field.Oneof.GoName
if _, ok := oneofs[fieldname]; ok {
continue
}
oneofs[fieldname] = struct{}{}
if p.IsWellKnownType(message) {
p.P(`switch c := m.`, fieldname, `.(type) {`)
for _, f := range field.Oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
marshalForwardOneOf(`(*`, p.WellKnownFieldMap(f), `)(c)`)
}
p.P(`}`)
} else {
p.P(`if vtmsg, ok := m.`, fieldname, `.(interface{`)
p.P(p.methodMarshalToSizedBuffer(), ` ([]byte) (int, error)`)
p.P(`}); ok {`)
marshalForwardOneOf("vtmsg")
p.P(`}`)
}
}
}
for i := len(message.Fields) - 1; i >= 0; i-- {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(false, &numGen, field)
}
}
}
p.P(`return len(dAtA) - i, nil`)
p.P(`}`)
p.P()
// Generate MarshalToVT methods for oneof fields
for _, field := range message.Fields {
if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() {
continue
}
ccTypeName := field.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalTo(), `(dAtA []byte) (int, error) {`)
p.P(`size := m.SizeVT()`)
p.P(`return m.`, p.methodMarshalToSizedBuffer(), `(dAtA[:size])`)
p.P(`}`)
p.P(``)
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalToSizedBuffer(), `(dAtA []byte) (int, error) {`)
p.P(`i := len(dAtA)`)
p.field(true, &numGen, field)
p.P(`return len(dAtA) - i, nil`)
p.P(`}`)
}
}
func (p *marshal) reverseListRange(expression ...string) string {
exp := strings.Join(expression, "")
p.P(`for iNdEx := len(`, exp, `) - 1; iNdEx >= 0; iNdEx-- {`)
return exp + `[iNdEx]`
}
func (p *marshal) marshalBackwardSize(varInt bool) {
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.P(`i -= size`)
if varInt {
p.encodeVarint(`size`)
}
}
func (p *marshal) marshalBackward(varName string, varInt bool, message *protogen.Message) {
switch {
case p.IsWellKnownType(message):
p.P(`size, err := (*`, p.WellKnownTypeMap(message), `)(`, varName, `).`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.marshalBackwardSize(varInt)
case p.IsLocalMessage(message):
p.P(`size, err := `, varName, `.`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.marshalBackwardSize(varInt)
default:
p.P(`if vtmsg, ok := interface{}(`, varName, `).(interface{`)
p.P(p.methodMarshalToSizedBuffer(), `([]byte) (int, error)`)
p.P(`}); ok{`)
p.P(`size, err := vtmsg.`, p.methodMarshalToSizedBuffer(), `(dAtA[:i])`)
p.marshalBackwardSize(varInt)
p.P(`} else {`)
p.P(`encoded, err := `, p.Ident(generator.ProtoPkg, "Marshal"), `(`, varName, `)`)
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.P(`i -= len(encoded)`)
p.P(`copy(dAtA[i:], encoded)`)
if varInt {
p.encodeVarint(`len(encoded)`)
}
p.P(`}`)
}
}

View File

@ -0,0 +1,105 @@
package pool
import (
"fmt"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
func init() {
generator.RegisterFeature("pool", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &pool{GeneratedFile: gen}
})
}
type pool struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*pool)(nil)
func (p *pool) GenerateFile(file *protogen.File) bool {
for _, message := range file.Messages {
p.message(message)
}
return p.once
}
func (p *pool) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(nested)
}
if message.Desc.IsMapEntry() || !p.ShouldPool(message) {
return
}
p.once = true
ccTypeName := message.GoIdent
p.P(`var vtprotoPool_`, ccTypeName, ` = `, p.Ident("sync", "Pool"), `{`)
p.P(`New: func() interface{} {`)
p.P(`return &`, ccTypeName, `{}`)
p.P(`},`)
p.P(`}`)
p.P(`func (m *`, ccTypeName, `) ResetVT() {`)
p.P(`if m != nil {`)
var saved []*protogen.Field
for _, field := range message.Fields {
fieldName := field.GoName
if field.Desc.IsList() {
switch field.Desc.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
p.P(`for _, mm := range m.`, fieldName, `{`)
if p.ShouldPool(field.Message) {
p.P(`mm.ResetVT()`)
} else {
p.P(`mm.Reset()`)
}
p.P(`}`)
}
p.P(fmt.Sprintf("f%d", len(saved)), ` := m.`, fieldName, `[:0]`)
saved = append(saved, field)
} else if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
if p.ShouldPool(field.Message) {
p.P(`if oneof, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent, `); ok {`)
p.P(`oneof.`, fieldName, `.ReturnToVTPool()`)
p.P(`}`)
}
} else {
switch field.Desc.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
if !field.Desc.IsMap() && p.ShouldPool(field.Message) {
p.P(`m.`, fieldName, `.ReturnToVTPool()`)
}
case protoreflect.BytesKind:
p.P(fmt.Sprintf("f%d", len(saved)), ` := m.`, fieldName, `[:0]`)
saved = append(saved, field)
}
}
}
p.P(`m.Reset()`)
for i, field := range saved {
p.P(`m.`, field.GoName, ` = `, fmt.Sprintf("f%d", i))
}
p.P(`}`)
p.P(`}`)
p.P(`func (m *`, ccTypeName, `) ReturnToVTPool() {`)
p.P(`if m != nil {`)
p.P(`m.ResetVT()`)
p.P(`vtprotoPool_`, ccTypeName, `.Put(m)`)
p.P(`}`)
p.P(`}`)
p.P(`func `, ccTypeName, `FromVTPool() *`, ccTypeName, `{`)
p.P(`return vtprotoPool_`, ccTypeName, `.Get().(*`, ccTypeName, `)`)
p.P(`}`)
}

View File

@ -0,0 +1,348 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package size
import (
"strconv"
"github.com/planetscale/vtprotobuf/generator"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
)
func init() {
generator.RegisterFeature("size", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &size{GeneratedFile: gen}
})
}
type size struct {
*generator.GeneratedFile
once bool
}
var _ generator.FeatureGenerator = (*size)(nil)
func (p *size) Name() string {
return "size"
}
func (p *size) GenerateFile(file *protogen.File) bool {
for _, message := range file.Messages {
p.message(message)
}
return p.once
}
func (p *size) messageSize(varName, sizeName string, message *protogen.Message) {
switch {
case p.IsWellKnownType(message):
p.P(`l = (*`, p.WellKnownTypeMap(message), `)(`, varName, `).`, sizeName, `()`)
case p.IsLocalMessage(message):
p.P(`l = `, varName, `.`, sizeName, `()`)
default:
p.P(`if size, ok := interface{}(`, varName, `).(interface{`)
p.P(sizeName, `() int`)
p.P(`}); ok{`)
p.P(`l = size.`, sizeName, `()`)
p.P(`} else {`)
p.P(`l = `, p.Ident(generator.ProtoPkg, "Size"), `(`, varName, `)`)
p.P(`}`)
}
}
func (p *size) field(oneof bool, field *protogen.Field, sizeName string) {
fieldname := field.GoName
nullable := field.Message != nil || (!oneof && field.Desc.HasPresence())
repeated := field.Desc.Cardinality() == protoreflect.Repeated
if repeated {
p.P(`if len(m.`, fieldname, `) > 0 {`)
} else if nullable {
p.P(`if m.`, fieldname, ` != nil {`)
}
packed := field.Desc.IsPacked()
wireType := generator.ProtoWireType(field.Desc.Kind())
fieldNumber := field.Desc.Number()
if packed {
wireType = protowire.BytesType
}
key := generator.KeySize(fieldNumber, wireType)
switch field.Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
if packed {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(len(m.`, fieldname, `)*8))`, `+len(m.`, fieldname, `)*8`)
} else if repeated {
p.P(`n+=`, strconv.Itoa(key+8), `*len(m.`, fieldname, `)`)
} else if !oneof && !nullable {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key+8))
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key+8))
}
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
if packed {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(len(m.`, fieldname, `)*4))`, `+len(m.`, fieldname, `)*4`)
} else if repeated {
p.P(`n+=`, strconv.Itoa(key+4), `*len(m.`, fieldname, `)`)
} else if !oneof && !nullable {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key+4))
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key+4))
}
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
if packed {
p.P(`l = 0`)
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`l+=`, p.Helper("SizeOfVarint"), `(uint64(e))`)
p.P(`}`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(l))+l`)
} else if repeated {
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(e))`)
p.P(`}`)
} else if nullable {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(*m.`, fieldname, `))`)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(m.`, fieldname, `))`)
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(m.`, fieldname, `))`)
}
case protoreflect.BoolKind:
if packed {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(len(m.`, fieldname, `)))`, `+len(m.`, fieldname, `)*1`)
} else if repeated {
p.P(`n+=`, strconv.Itoa(key+1), `*len(m.`, fieldname, `)`)
} else if !oneof && !nullable {
p.P(`if m.`, fieldname, ` {`)
p.P(`n+=`, strconv.Itoa(key+1))
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key+1))
}
case protoreflect.StringKind:
if repeated {
p.P(`for _, s := range m.`, fieldname, ` { `)
p.P(`l = len(s)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else if nullable {
p.P(`l=len(*m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
} else if !oneof {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`if l > 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
}
case protoreflect.GroupKind:
p.messageSize("m."+fieldname, sizeName, field.Message)
p.P(`n+=l+`, strconv.Itoa(2*key))
case protoreflect.MessageKind:
if field.Desc.IsMap() {
fieldKeySize := generator.KeySize(field.Desc.Number(), generator.ProtoWireType(field.Desc.Kind()))
keyKeySize := generator.KeySize(1, generator.ProtoWireType(field.Message.Fields[0].Desc.Kind()))
valueKeySize := generator.KeySize(2, generator.ProtoWireType(field.Message.Fields[1].Desc.Kind()))
p.P(`for k, v := range m.`, fieldname, ` { `)
p.P(`_ = k`)
p.P(`_ = v`)
sum := []interface{}{strconv.Itoa(keyKeySize)}
switch field.Message.Fields[0].Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
sum = append(sum, `8`)
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
sum = append(sum, `4`)
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
sum = append(sum, p.Helper("SizeOfVarint"), `(uint64(k))`)
case protoreflect.BoolKind:
sum = append(sum, `1`)
case protoreflect.StringKind, protoreflect.BytesKind:
sum = append(sum, `len(k)`, p.Helper("SizeOfVarint"), `(uint64(len(k)))`)
case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
sum = append(sum, p.Helper("SizeOfZigzag"), `(uint64(k))`)
}
switch field.Message.Fields[1].Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, strconv.Itoa(8))
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, strconv.Itoa(4))
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, p.Helper("SizeOfVarint"), `(uint64(v))`)
case protoreflect.BoolKind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, `1`)
case protoreflect.StringKind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, `len(v)`, p.Helper("SizeOfVarint"), `(uint64(len(v)))`)
case protoreflect.BytesKind:
p.P(`l = `, strconv.Itoa(valueKeySize), ` + len(v)+`, p.Helper("SizeOfVarint"), `(uint64(len(v)))`)
sum = append(sum, `l`)
case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
sum = append(sum, strconv.Itoa(valueKeySize))
sum = append(sum, p.Helper("SizeOfZigzag"), `(uint64(v))`)
case protoreflect.MessageKind:
p.P(`l = 0`)
p.P(`if v != nil {`)
p.messageSize("v", sizeName, field.Message.Fields[1].Message)
p.P(`}`)
p.P(`l += `, strconv.Itoa(valueKeySize), ` + `, p.Helper("SizeOfVarint"), `(uint64(l))`)
sum = append(sum, `l`)
}
mapEntrySize := []interface{}{"mapEntrySize := "}
for i, elt := range sum {
mapEntrySize = append(mapEntrySize, elt)
// if elt is not a string, then it is a helper function call
if _, ok := elt.(string); ok && i < len(sum)-1 {
mapEntrySize = append(mapEntrySize, "+")
}
}
p.P(mapEntrySize...)
p.P(`n+=mapEntrySize+`, fieldKeySize, `+`, p.Helper("SizeOfVarint"), `(uint64(mapEntrySize))`)
p.P(`}`)
} else if field.Desc.IsList() {
p.P(`for _, e := range m.`, fieldname, ` { `)
p.messageSize("e", sizeName, field.Message)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else {
p.messageSize("m."+fieldname, sizeName, field.Message)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
}
case protoreflect.BytesKind:
if repeated {
p.P(`for _, b := range m.`, fieldname, ` { `)
p.P(`l = len(b)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else if !oneof && !field.Desc.HasPresence() {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`if l > 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
p.P(`}`)
} else {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+`, p.Helper("SizeOfVarint"), `(uint64(l))`)
}
case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
if packed {
p.P(`l = 0`)
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`l+=`, p.Helper("SizeOfZigzag"), `(uint64(e))`)
p.P(`}`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfVarint"), `(uint64(l))+l`)
} else if repeated {
p.P(`for _, e := range m.`, fieldname, ` {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(e))`)
p.P(`}`)
} else if nullable {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(*m.`, fieldname, `))`)
} else if !oneof {
p.P(`if m.`, fieldname, ` != 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(m.`, fieldname, `))`)
p.P(`}`)
} else {
p.P(`n+=`, strconv.Itoa(key), `+`, p.Helper("SizeOfZigzag"), `(uint64(m.`, fieldname, `))`)
}
default:
panic("not implemented")
}
// Empty protobufs should emit a message or compatibility with Golang protobuf;
// See https://github.com/planetscale/vtprotobuf/issues/61
// Size is always 3 so just hardcode that here
if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() {
p.P("} else { n += 3 }")
} else if repeated || nullable {
p.P(`}`)
}
}
func (p *size) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
sizeName := "SizeVT"
ccTypeName := message.GoIdent.GoName
p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
p.P(`if m == nil {`)
p.P(`return 0`)
p.P(`}`)
p.P(`var l int`)
p.P(`_ = l`)
oneofs := make(map[string]struct{})
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(false, field, sizeName)
} else {
fieldname := field.Oneof.GoName
if _, ok := oneofs[fieldname]; ok {
continue
}
oneofs[fieldname] = struct{}{}
if p.IsWellKnownType(message) {
p.P(`switch c := m.`, fieldname, `.(type) {`)
for _, f := range field.Oneof.Fields {
p.P(`case *`, f.GoIdent, `:`)
p.P(`n += (*`, p.WellKnownFieldMap(f), `)(c).`, sizeName, `()`)
}
p.P(`}`)
} else {
p.P(`if vtmsg, ok := m.`, fieldname, `.(interface{ SizeVT() int }); ok {`)
p.P(`n+=vtmsg.`, sizeName, `()`)
p.P(`}`)
}
}
}
if !p.Wrapper() {
p.P(`n+=len(m.unknownFields)`)
}
p.P(`return n`)
p.P(`}`)
p.P()
for _, field := range message.Fields {
if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() {
continue
}
ccTypeName := field.GoIdent
if p.IsWellKnownType(message) && p.IsLocalMessage(message) {
ccTypeName.GoImportPath = ""
}
p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
p.P(`if m == nil {`)
p.P(`return 0`)
p.P(`}`)
p.P(`var l int`)
p.P(`_ = l`)
p.field(true, field, sizeName)
p.P(`return n`)
p.P(`}`)
}
}

View File

@ -0,0 +1,872 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unmarshal
import (
"fmt"
"strconv"
"strings"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
func init() {
generator.RegisterFeature("unmarshal", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &unmarshal{GeneratedFile: gen}
})
generator.RegisterFeature("unmarshal_unsafe", func(gen *generator.GeneratedFile) generator.FeatureGenerator {
return &unmarshal{GeneratedFile: gen, unsafe: true}
})
}
type unmarshal struct {
*generator.GeneratedFile
unsafe bool
once bool
}
var _ generator.FeatureGenerator = (*unmarshal)(nil)
func (p *unmarshal) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.message(proto3, message)
}
return p.once
}
func (p *unmarshal) methodUnmarshal() string {
if p.unsafe {
return "UnmarshalVTUnsafe"
}
return "UnmarshalVT"
}
func (p *unmarshal) decodeMessage(varName, buf string, message *protogen.Message) {
switch {
case p.IsWellKnownType(message):
p.P(`if err := (*`, p.WellKnownTypeMap(message), `)(`, varName, `).`, p.methodUnmarshal(), `(`, buf, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
case p.IsLocalMessage(message):
p.P(`if err := `, varName, `.`, p.methodUnmarshal(), `(`, buf, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
default:
p.P(`if unmarshal, ok := interface{}(`, varName, `).(interface{`)
p.P(p.methodUnmarshal(), `([]byte) error`)
p.P(`}); ok{`)
p.P(`if err := unmarshal.`, p.methodUnmarshal(), `(`, buf, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`} else {`)
p.P(`if err := `, p.Ident(generator.ProtoPkg, "Unmarshal"), `(`, buf, `, `, varName, `); err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`}`)
}
}
func (p *unmarshal) decodeVarint(varName string, typName string) {
p.P(`for shift := uint(0); ; shift += 7 {`)
p.P(`if shift >= 64 {`)
p.P(`return `, p.Helper("ErrIntOverflow"))
p.P(`}`)
p.P(`if iNdEx >= l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(`b := dAtA[iNdEx]`)
p.P(`iNdEx++`)
p.P(varName, ` |= `, typName, `(b&0x7F) << shift`)
p.P(`if b < 0x80 {`)
p.P(`break`)
p.P(`}`)
p.P(`}`)
}
func (p *unmarshal) decodeFixed32(varName string, typeName string) {
p.P(`if (iNdEx+4) > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(varName, ` = `, typeName, `(`, p.Ident("encoding/binary", "LittleEndian"), `.Uint32(dAtA[iNdEx:]))`)
p.P(`iNdEx += 4`)
}
func (p *unmarshal) decodeFixed64(varName string, typeName string) {
p.P(`if (iNdEx+8) > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(varName, ` = `, typeName, `(`, p.Ident("encoding/binary", "LittleEndian"), `.Uint64(dAtA[iNdEx:]))`)
p.P(`iNdEx += 8`)
}
func (p *unmarshal) declareMapField(varName string, nullable bool, field *protogen.Field) {
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
p.P(`var `, varName, ` float64`)
case protoreflect.FloatKind:
p.P(`var `, varName, ` float32`)
case protoreflect.Int64Kind:
p.P(`var `, varName, ` int64`)
case protoreflect.Uint64Kind:
p.P(`var `, varName, ` uint64`)
case protoreflect.Int32Kind:
p.P(`var `, varName, ` int32`)
case protoreflect.Fixed64Kind:
p.P(`var `, varName, ` uint64`)
case protoreflect.Fixed32Kind:
p.P(`var `, varName, ` uint32`)
case protoreflect.BoolKind:
p.P(`var `, varName, ` bool`)
case protoreflect.StringKind:
p.P(`var `, varName, ` `, field.GoIdent)
case protoreflect.MessageKind:
msgname := field.GoIdent
if nullable {
p.P(`var `, varName, ` *`, msgname)
} else {
p.P(varName, ` := &`, msgname, `{}`)
}
case protoreflect.BytesKind:
p.P(varName, ` := []byte{}`)
case protoreflect.Uint32Kind:
p.P(`var `, varName, ` uint32`)
case protoreflect.EnumKind:
p.P(`var `, varName, ` `, field.GoIdent)
case protoreflect.Sfixed32Kind:
p.P(`var `, varName, ` int32`)
case protoreflect.Sfixed64Kind:
p.P(`var `, varName, ` int64`)
case protoreflect.Sint32Kind:
p.P(`var `, varName, ` int32`)
case protoreflect.Sint64Kind:
p.P(`var `, varName, ` int64`)
}
}
func (p *unmarshal) mapField(varName string, field *protogen.Field) {
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
p.P(`var `, varName, `temp uint64`)
p.decodeFixed64(varName+"temp", "uint64")
p.P(varName, ` = `, p.Ident("math", "Float64frombits"), `(`, varName, `temp)`)
case protoreflect.FloatKind:
p.P(`var `, varName, `temp uint32`)
p.decodeFixed32(varName+"temp", "uint32")
p.P(varName, ` = `, p.Ident("math", "Float32frombits"), `(`, varName, `temp)`)
case protoreflect.Int64Kind:
p.decodeVarint(varName, "int64")
case protoreflect.Uint64Kind:
p.decodeVarint(varName, "uint64")
case protoreflect.Int32Kind:
p.decodeVarint(varName, "int32")
case protoreflect.Fixed64Kind:
p.decodeFixed64(varName, "uint64")
case protoreflect.Fixed32Kind:
p.decodeFixed32(varName, "uint32")
case protoreflect.BoolKind:
p.P(`var `, varName, `temp int`)
p.decodeVarint(varName+"temp", "int")
p.P(varName, ` = bool(`, varName, `temp != 0)`)
case protoreflect.StringKind:
p.P(`var stringLen`, varName, ` uint64`)
p.decodeVarint("stringLen"+varName, "uint64")
p.P(`intStringLen`, varName, ` := int(stringLen`, varName, `)`)
p.P(`if intStringLen`, varName, ` < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postStringIndex`, varName, ` := iNdEx + intStringLen`, varName)
p.P(`if postStringIndex`, varName, ` < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postStringIndex`, varName, ` > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if p.unsafe {
p.P(`if intStringLen`, varName, ` == 0 {`)
p.P(varName, ` = ""`)
p.P(`} else {`)
p.P(varName, ` = `, p.Ident("unsafe", `String`), `(&dAtA[iNdEx], intStringLen`, varName, `)`)
p.P(`}`)
} else {
p.P(varName, ` = `, "string", `(dAtA[iNdEx:postStringIndex`, varName, `])`)
}
p.P(`iNdEx = postStringIndex`, varName)
case protoreflect.MessageKind:
p.P(`var mapmsglen int`)
p.decodeVarint("mapmsglen", "int")
p.P(`if mapmsglen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postmsgIndex := iNdEx + mapmsglen`)
p.P(`if postmsgIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postmsgIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
buf := `dAtA[iNdEx:postmsgIndex]`
p.P(varName, ` = &`, p.noStarOrSliceType(field), `{}`)
p.decodeMessage(varName, buf, field.Message)
p.P(`iNdEx = postmsgIndex`)
case protoreflect.BytesKind:
p.P(`var mapbyteLen uint64`)
p.decodeVarint("mapbyteLen", "uint64")
p.P(`intMapbyteLen := int(mapbyteLen)`)
p.P(`if intMapbyteLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postbytesIndex := iNdEx + intMapbyteLen`)
p.P(`if postbytesIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postbytesIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if p.unsafe {
p.P(varName, ` = dAtA[iNdEx:postbytesIndex]`)
} else {
p.P(varName, ` = make([]byte, mapbyteLen)`)
p.P(`copy(`, varName, `, dAtA[iNdEx:postbytesIndex])`)
}
p.P(`iNdEx = postbytesIndex`)
case protoreflect.Uint32Kind:
p.decodeVarint(varName, "uint32")
case protoreflect.EnumKind:
goTypV, _ := p.FieldGoType(field)
p.decodeVarint(varName, goTypV)
case protoreflect.Sfixed32Kind:
p.decodeFixed32(varName, "int32")
case protoreflect.Sfixed64Kind:
p.decodeFixed64(varName, "int64")
case protoreflect.Sint32Kind:
p.P(`var `, varName, `temp int32`)
p.decodeVarint(varName+"temp", "int32")
p.P(varName, `temp = int32((uint32(`, varName, `temp) >> 1) ^ uint32(((`, varName, `temp&1)<<31)>>31))`)
p.P(varName, ` = int32(`, varName, `temp)`)
case protoreflect.Sint64Kind:
p.P(`var `, varName, `temp uint64`)
p.decodeVarint(varName+"temp", "uint64")
p.P(varName, `temp = (`, varName, `temp >> 1) ^ uint64((int64(`, varName, `temp&1)<<63)>>63)`)
p.P(varName, ` = int64(`, varName, `temp)`)
}
}
func (p *unmarshal) noStarOrSliceType(field *protogen.Field) string {
typ, _ := p.FieldGoType(field)
if typ[0] == '[' && typ[1] == ']' {
typ = typ[2:]
}
if typ[0] == '*' {
typ = typ[1:]
}
return typ
}
func (p *unmarshal) fieldItem(field *protogen.Field, fieldname string, message *protogen.Message, proto3 bool) {
repeated := field.Desc.Cardinality() == protoreflect.Repeated
typ := p.noStarOrSliceType(field)
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
nullable := field.Oneof != nil && field.Oneof.Desc.IsSynthetic()
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
p.P(`var v uint64`)
p.decodeFixed64("v", "uint64")
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", typ, "(", p.Ident("math", `Float64frombits`), `(v))}`)
} else if repeated {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float64frombits"), `(v))`)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v2)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, "(", p.Ident("math", "Float64frombits"), `(v))`)
} else {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float64frombits"), `(v))`)
p.P(`m.`, fieldname, ` = &v2`)
}
case protoreflect.FloatKind:
p.P(`var v uint32`)
p.decodeFixed32("v", "uint32")
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", typ, "(", p.Ident("math", "Float32frombits"), `(v))}`)
} else if repeated {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float32frombits"), `(v))`)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v2)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, "(", p.Ident("math", "Float32frombits"), `(v))`)
} else {
p.P(`v2 := `, typ, "(", p.Ident("math", "Float32frombits"), `(v))`)
p.P(`m.`, fieldname, ` = &v2`)
}
case protoreflect.Int64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Uint64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Int32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Fixed64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed64("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Fixed32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed32("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.BoolKind:
p.P(`var v int`)
p.decodeVarint("v", "int")
if oneof {
p.P(`b := `, typ, `(v != 0)`)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: b}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(v != 0))`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, `(v != 0)`)
} else {
p.P(`b := `, typ, `(v != 0)`)
p.P(`m.`, fieldname, ` = &b`)
}
case protoreflect.StringKind:
p.P(`var stringLen uint64`)
p.decodeVarint("stringLen", "uint64")
p.P(`intStringLen := int(stringLen)`)
p.P(`if intStringLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + intStringLen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
str := "string(dAtA[iNdEx:postIndex])"
if p.unsafe {
str = "stringValue"
p.P(`var stringValue string`)
p.P(`if intStringLen > 0 {`)
p.P(`stringValue = `, p.Ident("unsafe", `String`), `(&dAtA[iNdEx], intStringLen)`)
p.P(`}`)
}
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", str, `}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, str, `)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, str)
} else {
p.P(`s := `, str)
p.P(`m.`, fieldname, ` = &s`)
}
p.P(`iNdEx = postIndex`)
case protoreflect.GroupKind:
p.P(`groupStart := iNdEx`)
p.P(`for {`)
p.P(`maybeGroupEnd := iNdEx`)
p.P(`var groupFieldWire uint64`)
p.decodeVarint("groupFieldWire", "uint64")
p.P(`groupWireType := int(wire & 0x7)`)
p.P(`if groupWireType == `, strconv.Itoa(int(protowire.EndGroupType)), `{`)
p.decodeMessage("m."+fieldname, "dAtA[groupStart:maybeGroupEnd]", field.Message)
p.P(`break`)
p.P(`}`)
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`if (skippy < 0) || (iNdEx + skippy) < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`iNdEx += skippy`)
p.P(`}`)
case protoreflect.MessageKind:
p.P(`var msglen int`)
p.decodeVarint("msglen", "int")
p.P(`if msglen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + msglen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if oneof {
buf := `dAtA[iNdEx:postIndex]`
msgname := p.noStarOrSliceType(field)
p.P(`if oneof, ok := m.`, fieldname, `.(*`, field.GoIdent, `); ok {`)
p.decodeMessage("oneof."+field.GoName, buf, field.Message)
p.P(`} else {`)
if p.ShouldPool(message) && p.ShouldPool(field.Message) {
p.P(`v := `, msgname, `FromVTPool()`)
} else {
p.P(`v := &`, msgname, `{}`)
}
p.decodeMessage("v", buf, field.Message)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
p.P(`}`)
} else if field.Desc.IsMap() {
goTyp, _ := p.FieldGoType(field)
goTypK, _ := p.FieldGoType(field.Message.Fields[0])
goTypV, _ := p.FieldGoType(field.Message.Fields[1])
p.P(`if m.`, fieldname, ` == nil {`)
p.P(`m.`, fieldname, ` = make(`, goTyp, `)`)
p.P(`}`)
p.P("var mapkey ", goTypK)
p.P("var mapvalue ", goTypV)
p.P(`for iNdEx < postIndex {`)
p.P(`entryPreIndex := iNdEx`)
p.P(`var wire uint64`)
p.decodeVarint("wire", "uint64")
p.P(`fieldNum := int32(wire >> 3)`)
p.P(`if fieldNum == 1 {`)
p.mapField("mapkey", field.Message.Fields[0])
p.P(`} else if fieldNum == 2 {`)
p.mapField("mapvalue", field.Message.Fields[1])
p.P(`} else {`)
p.P(`iNdEx = entryPreIndex`)
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`if (skippy < 0) || (iNdEx + skippy) < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if (iNdEx + skippy) > postIndex {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(`iNdEx += skippy`)
p.P(`}`)
p.P(`}`)
p.P(`m.`, fieldname, `[mapkey] = mapvalue`)
} else if repeated {
if p.ShouldPool(message) {
p.P(`if len(m.`, fieldname, `) == cap(m.`, fieldname, `) {`)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, &`, field.Message.GoIdent, `{})`)
p.P(`} else {`)
p.P(`m.`, fieldname, ` = m.`, fieldname, `[:len(m.`, fieldname, `) + 1]`)
p.P(`if m.`, fieldname, `[len(m.`, fieldname, `) - 1] == nil {`)
p.P(`m.`, fieldname, `[len(m.`, fieldname, `) - 1] = &`, field.Message.GoIdent, `{}`)
p.P(`}`)
p.P(`}`)
} else {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, &`, field.Message.GoIdent, `{})`)
}
varname := fmt.Sprintf("m.%s[len(m.%s) - 1]", fieldname, fieldname)
buf := `dAtA[iNdEx:postIndex]`
p.decodeMessage(varname, buf, field.Message)
} else {
p.P(`if m.`, fieldname, ` == nil {`)
if p.ShouldPool(message) && p.ShouldPool(field.Message) {
p.P(`m.`, fieldname, ` = `, field.Message.GoIdent, `FromVTPool()`)
} else {
p.P(`m.`, fieldname, ` = &`, field.Message.GoIdent, `{}`)
}
p.P(`}`)
p.decodeMessage("m."+fieldname, "dAtA[iNdEx:postIndex]", field.Message)
}
p.P(`iNdEx = postIndex`)
case protoreflect.BytesKind:
p.P(`var byteLen int`)
p.decodeVarint("byteLen", "int")
p.P(`if byteLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + byteLen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if oneof {
if p.unsafe {
p.P(`v := dAtA[iNdEx:postIndex]`)
} else {
p.P(`v := make([]byte, postIndex-iNdEx)`)
p.P(`copy(v, dAtA[iNdEx:postIndex])`)
}
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
if p.unsafe {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, dAtA[iNdEx:postIndex])`)
} else {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, make([]byte, postIndex-iNdEx))`)
p.P(`copy(m.`, fieldname, `[len(m.`, fieldname, `)-1], dAtA[iNdEx:postIndex])`)
}
} else {
if p.unsafe {
p.P(`m.`, fieldname, ` = dAtA[iNdEx:postIndex]`)
} else {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `[:0] , dAtA[iNdEx:postIndex]...)`)
p.P(`if m.`, fieldname, ` == nil {`)
p.P(`m.`, fieldname, ` = []byte{}`)
p.P(`}`)
}
}
p.P(`iNdEx = postIndex`)
case protoreflect.Uint32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.EnumKind:
if oneof {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeVarint("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sfixed32Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed32("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed32("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sfixed64Kind:
if oneof {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = 0`)
p.decodeFixed64("m."+fieldname, typ)
} else {
p.P(`var v `, typ)
p.decodeFixed64("v", typ)
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sint32Kind:
p.P(`var v `, typ)
p.decodeVarint("v", typ)
p.P(`v = `, typ, `((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31))`)
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = v`)
} else {
p.P(`m.`, fieldname, ` = &v`)
}
case protoreflect.Sint64Kind:
p.P(`var v uint64`)
p.decodeVarint("v", "uint64")
p.P(`v = (v >> 1) ^ uint64((int64(v&1)<<63)>>63)`)
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", typ, `(v)}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(v))`)
} else if proto3 && !nullable {
p.P(`m.`, fieldname, ` = `, typ, `(v)`)
} else {
p.P(`v2 := `, typ, `(v)`)
p.P(`m.`, fieldname, ` = &v2`)
}
default:
panic("not implemented")
}
}
func (p *unmarshal) field(proto3, oneof bool, field *protogen.Field, message *protogen.Message, required protoreflect.FieldNumbers) {
fieldname := field.GoName
errFieldname := fieldname
if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() {
fieldname = field.Oneof.GoName
}
p.P(`case `, strconv.Itoa(int(field.Desc.Number())), `:`)
wireType := generator.ProtoWireType(field.Desc.Kind())
if field.Desc.IsList() && wireType != protowire.BytesType {
p.P(`if wireType == `, strconv.Itoa(int(wireType)), `{`)
p.fieldItem(field, fieldname, message, false)
p.P(`} else if wireType == `, strconv.Itoa(int(protowire.BytesType)), `{`)
p.P(`var packedLen int`)
p.decodeVarint("packedLen", "int")
p.P(`if packedLen < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`postIndex := iNdEx + packedLen`)
p.P(`if postIndex < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if postIndex > l {`)
p.P(`return `, p.Ident("io", "ErrUnexpectedEOF"))
p.P(`}`)
p.P(`var elementCount int`)
switch field.Desc.Kind() {
case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
p.P(`elementCount = packedLen/`, 8)
case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
p.P(`elementCount = packedLen/`, 4)
case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind, protoreflect.Sint64Kind:
p.P(`var count int`)
p.P(`for _, integer := range dAtA[iNdEx:postIndex] {`)
p.P(`if integer < 128 {`)
p.P(`count++`)
p.P(`}`)
p.P(`}`)
p.P(`elementCount = count`)
case protoreflect.BoolKind:
p.P(`elementCount = packedLen`)
}
if p.ShouldPool(message) {
p.P(`if elementCount != 0 && len(m.`, fieldname, `) == 0 && cap(m.`, fieldname, `) < elementCount {`)
} else {
p.P(`if elementCount != 0 && len(m.`, fieldname, `) == 0 {`)
}
fieldtyp, _ := p.FieldGoType(field)
p.P(`m.`, fieldname, ` = make(`, fieldtyp, `, 0, elementCount)`)
p.P(`}`)
p.P(`for iNdEx < postIndex {`)
p.fieldItem(field, fieldname, message, false)
p.P(`}`)
p.P(`} else {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: wrong wireType = %d for field `, errFieldname, `", wireType)`)
p.P(`}`)
} else {
p.P(`if wireType != `, strconv.Itoa(int(wireType)), `{`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: wrong wireType = %d for field `, errFieldname, `", wireType)`)
p.P(`}`)
p.fieldItem(field, fieldname, message, proto3)
}
if field.Desc.Cardinality() == protoreflect.Required {
var fieldBit int
for fieldBit = 0; fieldBit < required.Len(); fieldBit++ {
if required.Get(fieldBit) == field.Desc.Number() {
break
}
}
if fieldBit == required.Len() {
panic("missing required field")
}
p.P(`hasFields[`, strconv.Itoa(fieldBit/64), `] |= uint64(`, fmt.Sprintf("0x%08x", uint64(1)<<(fieldBit%64)), `)`)
}
}
func (p *unmarshal) message(proto3 bool, message *protogen.Message) {
for _, nested := range message.Messages {
p.message(proto3, nested)
}
if message.Desc.IsMapEntry() {
return
}
p.once = true
ccTypeName := message.GoIdent.GoName
required := message.Desc.RequiredNumbers()
p.P(`func (m *`, ccTypeName, `) `, p.methodUnmarshal(), `(dAtA []byte) error {`)
if required.Len() > 0 {
p.P(`var hasFields [`, strconv.Itoa(1+(required.Len()-1)/64), `]uint64`)
}
p.P(`l := len(dAtA)`)
p.P(`iNdEx := 0`)
p.P(`for iNdEx < l {`)
p.P(`preIndex := iNdEx`)
p.P(`var wire uint64`)
p.decodeVarint("wire", "uint64")
p.P(`fieldNum := int32(wire >> 3)`)
p.P(`wireType := int(wire & 0x7)`)
p.P(`if wireType == `, strconv.Itoa(int(protowire.EndGroupType)), ` {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: `, message.GoIdent.GoName, `: wiretype end group for non-group")`)
p.P(`}`)
p.P(`if fieldNum <= 0 {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: `, message.GoIdent.GoName, `: illegal tag %d (wire type %d)", fieldNum, wire)`)
p.P(`}`)
p.P(`switch fieldNum {`)
for _, field := range message.Fields {
p.field(proto3, false, field, message, required)
}
p.P(`default:`)
p.P(`iNdEx=preIndex`)
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`if (skippy < 0) || (iNdEx + skippy) < 0 {`)
p.P(`return `, p.Helper("ErrInvalidLength"))
p.P(`}`)
p.P(`if (iNdEx + skippy) > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if message.Desc.ExtensionRanges().Len() > 0 {
c := []string{}
eranges := message.Desc.ExtensionRanges()
for e := 0; e < eranges.Len(); e++ {
erange := eranges.Get(e)
c = append(c, `((fieldNum >= `+strconv.Itoa(int(erange[0]))+`) && (fieldNum < `+strconv.Itoa(int(erange[1]))+`))`)
}
p.P(`if `, strings.Join(c, "||"), `{`)
p.P(`err = `, p.Ident(generator.ProtoPkg, "UnmarshalOptions"), `{AllowPartial: true}.Unmarshal(dAtA[iNdEx:iNdEx+skippy], m)`)
p.P(`if err != nil {`)
p.P(`return err`)
p.P(`}`)
p.P(`iNdEx += skippy`)
p.P(`} else {`)
}
if !p.Wrapper() {
p.P(`m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...)`)
}
p.P(`iNdEx += skippy`)
if message.Desc.ExtensionRanges().Len() > 0 {
p.P(`}`)
}
p.P(`}`)
p.P(`}`)
for _, field := range message.Fields {
if field.Desc.Cardinality() != protoreflect.Required {
continue
}
var fieldBit int
for fieldBit = 0; fieldBit < required.Len(); fieldBit++ {
if required.Get(fieldBit) == field.Desc.Number() {
break
}
}
if fieldBit == required.Len() {
panic("missing required field")
}
p.P(`if hasFields[`, strconv.Itoa(int(fieldBit/64)), `] & uint64(`, fmt.Sprintf("0x%08x", uint64(1)<<(fieldBit%64)), `) == 0 {`)
p.P(`return `, p.Ident("fmt", "Errorf"), `("proto: required field `, field.Desc.Name(), ` not set")`)
p.P(`}`)
}
p.P()
p.P(`if iNdEx > l {`)
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
p.P(`return nil`)
p.P(`}`)
}

View File

@ -0,0 +1,58 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"fmt"
"sort"
"google.golang.org/protobuf/compiler/protogen"
)
var defaultFeatures = make(map[string]Feature)
func findFeatures(featureNames []string) ([]Feature, error) {
required := make(map[string]Feature)
for _, name := range featureNames {
if name == "all" {
required = defaultFeatures
break
}
feat, ok := defaultFeatures[name]
if !ok {
return nil, fmt.Errorf("unknown feature: %q", name)
}
required[name] = feat
}
type namefeat struct {
name string
feat Feature
}
var sorted []namefeat
for name, feat := range required {
sorted = append(sorted, namefeat{name, feat})
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].name < sorted[j].name
})
var features []Feature
for _, sp := range sorted {
features = append(features, sp.feat)
}
return features, nil
}
func RegisterFeature(name string, feat Feature) {
defaultFeatures[name] = feat
}
type Feature func(gen *GeneratedFile) FeatureGenerator
type FeatureGenerator interface {
GenerateFile(file *protogen.File) bool
}

View File

@ -0,0 +1,198 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"fmt"
"github.com/planetscale/vtprotobuf/vtproto"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
type GeneratedFile struct {
*protogen.GeneratedFile
Config *Config
LocalPackages map[protoreflect.FullName]bool
}
func (p *GeneratedFile) Ident(path, ident string) string {
return p.QualifiedGoIdent(protogen.GoImportPath(path).Ident(ident))
}
func (b *GeneratedFile) ShouldPool(message *protogen.Message) bool {
// Do not generate pool if message is nil or message excluded by external rules
if message == nil || b.Config.PoolableExclude.Contains(message.GoIdent) {
return false
}
if b.Config.Poolable.Contains(message.GoIdent) {
return true
}
ext := proto.GetExtension(message.Desc.Options(), vtproto.E_Mempool)
if mempool, ok := ext.(bool); ok {
return mempool
}
return false
}
func (b *GeneratedFile) Alloc(vname string, message *protogen.Message, isQualifiedIdent bool) {
ident := message.GoIdent.GoName
if isQualifiedIdent {
ident = b.QualifiedGoIdent(message.GoIdent)
}
if b.ShouldPool(message) {
b.P(vname, " := ", ident, `FromVTPool()`)
} else {
b.P(vname, " := new(", ident, `)`)
}
}
func (p *GeneratedFile) FieldGoType(field *protogen.Field) (goType string, pointer bool) {
if field.Desc.IsWeak() {
return "struct{}", false
}
pointer = field.Desc.HasPresence()
switch field.Desc.Kind() {
case protoreflect.BoolKind:
goType = "bool"
case protoreflect.EnumKind:
goType = p.QualifiedGoIdent(field.Enum.GoIdent)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
goType = "int32"
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
goType = "uint32"
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
goType = "int64"
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
goType = "uint64"
case protoreflect.FloatKind:
goType = "float32"
case protoreflect.DoubleKind:
goType = "float64"
case protoreflect.StringKind:
goType = "string"
case protoreflect.BytesKind:
goType = "[]byte"
pointer = false // rely on nullability of slices for presence
case protoreflect.MessageKind, protoreflect.GroupKind:
goType = "*" + p.QualifiedGoIdent(field.Message.GoIdent)
pointer = false // pointer captured as part of the type
}
switch {
case field.Desc.IsList():
return "[]" + goType, false
case field.Desc.IsMap():
keyType, _ := p.FieldGoType(field.Message.Fields[0])
valType, _ := p.FieldGoType(field.Message.Fields[1])
return fmt.Sprintf("map[%v]%v", keyType, valType), false
}
return goType, pointer
}
func (p *GeneratedFile) IsLocalMessage(message *protogen.Message) bool {
if message == nil {
return false
}
pkg := message.Desc.ParentFile().Package()
return p.LocalPackages[pkg]
}
func (p *GeneratedFile) IsLocalField(field *protogen.Field) bool {
if field == nil {
return false
}
pkg := field.Desc.ParentFile().Package()
return p.LocalPackages[pkg]
}
const vtHelpersPackage = protogen.GoImportPath("github.com/planetscale/vtprotobuf/protohelpers")
var helpers = map[string]protogen.GoIdent{
"EncodeVarint": {GoName: "EncodeVarint", GoImportPath: vtHelpersPackage},
"SizeOfVarint": {GoName: "SizeOfVarint", GoImportPath: vtHelpersPackage},
"SizeOfZigzag": {GoName: "SizeOfZigzag", GoImportPath: vtHelpersPackage},
"Skip": {GoName: "Skip", GoImportPath: vtHelpersPackage},
"ErrInvalidLength": {GoName: "ErrInvalidLength", GoImportPath: vtHelpersPackage},
"ErrIntOverflow": {GoName: "ErrIntOverflow", GoImportPath: vtHelpersPackage},
"ErrUnexpectedEndOfGroup": {GoName: "ErrUnexpectedEndOfGroup", GoImportPath: vtHelpersPackage},
}
func (p *GeneratedFile) Helper(name string) protogen.GoIdent {
return helpers[name]
}
const vtWellKnownPackage = protogen.GoImportPath("github.com/planetscale/vtprotobuf/types/known/")
var wellKnownTypes = map[protoreflect.FullName]protogen.GoIdent{
"google.protobuf.Any": {GoName: "Any", GoImportPath: vtWellKnownPackage + "anypb"},
"google.protobuf.Duration": {GoName: "Duration", GoImportPath: vtWellKnownPackage + "durationpb"},
"google.protobuf.Empty": {GoName: "Empty", GoImportPath: vtWellKnownPackage + "emptypb"},
"google.protobuf.FieldMask": {GoName: "FieldMask", GoImportPath: vtWellKnownPackage + "fieldmaskpb"},
"google.protobuf.Timestamp": {GoName: "Timestamp", GoImportPath: vtWellKnownPackage + "timestamppb"},
"google.protobuf.DoubleValue": {GoName: "DoubleValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.FloatValue": {GoName: "FloatValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.Int64Value": {GoName: "Int64Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.UInt64Value": {GoName: "UInt64Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.Int32Value": {GoName: "Int32Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.UInt32Value": {GoName: "UInt32Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.BoolValue": {GoName: "BoolValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.StringValue": {GoName: "StringValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.BytesValue": {GoName: "BytesValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.Struct": {GoName: "Struct", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value": {GoName: "Value", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.ListValue": {GoName: "ListValue", GoImportPath: vtWellKnownPackage + "structpb"},
}
var wellKnownFields = map[protoreflect.FullName]protogen.GoIdent{
"google.protobuf.Value.null_value": {GoName: "Value_NullValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.number_value": {GoName: "Value_NumberValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.string_value": {GoName: "Value_StringValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.bool_value": {GoName: "Value_BoolValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.struct_value": {GoName: "Value_StructValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.list_value": {GoName: "Value_ListValue", GoImportPath: vtWellKnownPackage + "structpb"},
}
func (p *GeneratedFile) IsWellKnownType(message *protogen.Message) bool {
if message == nil {
return false
}
_, ok := wellKnownTypes[message.Desc.FullName()]
return ok
}
func (p *GeneratedFile) WellKnownFieldMap(field *protogen.Field) protogen.GoIdent {
if field == nil {
return protogen.GoIdent{}
}
res, ff := wellKnownFields[field.Desc.FullName()]
if !ff {
panic(field.Desc.FullName())
}
if p.IsLocalField(field) {
res.GoImportPath = ""
}
return res
}
func (p *GeneratedFile) WellKnownTypeMap(message *protogen.Message) protogen.GoIdent {
if message == nil {
return protogen.GoIdent{}
}
res := wellKnownTypes[message.Desc.FullName()]
if p.IsLocalMessage(message) {
res.GoImportPath = ""
}
return res
}
func (p *GeneratedFile) Wrapper() bool {
return p.Config.Wrap
}

View File

@ -0,0 +1,165 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"fmt"
"runtime/debug"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/types/pluginpb"
"github.com/planetscale/vtprotobuf/generator/pattern"
)
type ObjectSet struct {
mp map[string]bool
}
func NewObjectSet() ObjectSet {
return ObjectSet{
mp: map[string]bool{},
}
}
func (o ObjectSet) String() string {
return fmt.Sprintf("%#v", o)
}
func (o ObjectSet) Contains(g protogen.GoIdent) bool {
objectPath := fmt.Sprintf("%s.%s", string(g.GoImportPath), g.GoName)
for wildcard := range o.mp {
// Ignore malformed pattern error because pattern already checked in Set
if ok, _ := pattern.Match(wildcard, objectPath); ok {
return true
}
}
return false
}
func (o ObjectSet) Set(s string) error {
if !pattern.ValidatePattern(s) {
return pattern.ErrBadPattern
}
o.mp[s] = true
return nil
}
type Config struct {
// Poolable rules determines if pool feature generate for particular message
Poolable ObjectSet
// PoolableExclude rules determines if pool feature disabled for particular message
PoolableExclude ObjectSet
Wrap bool
WellKnownTypes bool
AllowEmpty bool
BuildTag string
}
type Generator struct {
plugin *protogen.Plugin
cfg *Config
features []Feature
local map[protoreflect.FullName]bool
}
const SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
func NewGenerator(plugin *protogen.Plugin, featureNames []string, cfg *Config) (*Generator, error) {
plugin.SupportedFeatures = SupportedFeatures
features, err := findFeatures(featureNames)
if err != nil {
return nil, err
}
local := make(map[protoreflect.FullName]bool)
for _, f := range plugin.Files {
if f.Generate {
local[f.Desc.Package()] = true
}
}
return &Generator{
plugin: plugin,
cfg: cfg,
features: features,
local: local,
}, nil
}
func (gen *Generator) Generate() {
for _, file := range gen.plugin.Files {
if !file.Generate {
continue
}
var importPath protogen.GoImportPath
if !gen.cfg.Wrap {
importPath = file.GoImportPath
}
gf := gen.plugin.NewGeneratedFile(file.GeneratedFilenamePrefix+"_vtproto.pb.go", importPath)
gen.generateFile(gf, file)
}
}
func (gen *Generator) generateFile(gf *protogen.GeneratedFile, file *protogen.File) {
p := &GeneratedFile{
GeneratedFile: gf,
Config: gen.cfg,
LocalPackages: gen.local,
}
if p.Config.BuildTag != "" {
// Support both forms of tags for maximum compatibility
p.P("//go:build ", p.Config.BuildTag)
p.P("// +build ", p.Config.BuildTag)
}
p.P("// Code generated by protoc-gen-go-vtproto. DO NOT EDIT.")
if bi, ok := debug.ReadBuildInfo(); ok {
p.P("// protoc-gen-go-vtproto version: ", bi.Main.Version)
}
p.P("// source: ", file.Desc.Path())
p.P()
p.P("package ", file.GoPackageName)
p.P()
protoimplPackage := protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl")
p.P("const (")
p.P("// Verify that this generated code is sufficiently up-to-date.")
p.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")")
p.P("// Verify that runtime/protoimpl is sufficiently up-to-date.")
p.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")")
p.P(")")
p.P()
if p.Wrapper() {
for _, msg := range file.Messages {
p.P(`type `, msg.GoIdent.GoName, ` `, msg.GoIdent)
for _, one := range msg.Oneofs {
for _, field := range one.Fields {
p.P(`type `, field.GoIdent.GoName, ` `, field.GoIdent)
}
}
}
}
var generated bool
for _, feat := range gen.features {
featGenerator := feat(p)
if featGenerator.GenerateFile(file) {
generated = true
}
}
if !generated && !gen.cfg.AllowEmpty {
gf.Skip()
}
}

View File

@ -0,0 +1,47 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
)
const ProtoPkg = "google.golang.org/protobuf/proto"
func KeySize(fieldNumber protoreflect.FieldNumber, wireType protowire.Type) int {
x := uint32(fieldNumber)<<3 | uint32(wireType)
size := 0
for size = 0; x > 127; size++ {
x >>= 7
}
size++
return size
}
var wireTypes = map[protoreflect.Kind]protowire.Type{
protoreflect.BoolKind: protowire.VarintType,
protoreflect.EnumKind: protowire.VarintType,
protoreflect.Int32Kind: protowire.VarintType,
protoreflect.Sint32Kind: protowire.VarintType,
protoreflect.Uint32Kind: protowire.VarintType,
protoreflect.Int64Kind: protowire.VarintType,
protoreflect.Sint64Kind: protowire.VarintType,
protoreflect.Uint64Kind: protowire.VarintType,
protoreflect.Sfixed32Kind: protowire.Fixed32Type,
protoreflect.Fixed32Kind: protowire.Fixed32Type,
protoreflect.FloatKind: protowire.Fixed32Type,
protoreflect.Sfixed64Kind: protowire.Fixed64Type,
protoreflect.Fixed64Kind: protowire.Fixed64Type,
protoreflect.DoubleKind: protowire.Fixed64Type,
protoreflect.StringKind: protowire.BytesType,
protoreflect.BytesKind: protowire.BytesType,
protoreflect.MessageKind: protowire.BytesType,
protoreflect.GroupKind: protowire.StartGroupType,
}
func ProtoWireType(k protoreflect.Kind) protowire.Type {
return wireTypes[k]
}

View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Bob Matcuk
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,457 @@
package pattern
import (
"path"
"unicode/utf8"
)
var ErrBadPattern = path.ErrBadPattern
// Match reports whether name matches the shell pattern.
// The pattern syntax is:
//
// pattern:
// { term }
// term:
// '*' matches any sequence of non-path-separators
// '/**/' matches zero or more directories
// '?' matches any single non-path-separator character
// '[' [ '^' '!' ] { character-range } ']'
// character class (must be non-empty)
// starting with `^` or `!` negates the class
// '{' { term } [ ',' { term } ... ] '}'
// alternatives
// c matches character c (c != '*', '?', '\\', '[')
// '\\' c matches character c
//
// character-range:
// c matches character c (c != '\\', '-', ']')
// '\\' c matches character c
// lo '-' hi matches character c for lo <= c <= hi
//
// Match returns true if `name` matches the file name `pattern`. `name` and
// `pattern` are split on forward slash (`/`) characters and may be relative or
// absolute.
//
// Match requires pattern to match all of name, not just a substring.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
//
// A doublestar (`**`) should appear surrounded by path separators such as
// `/**/`. A mid-pattern doublestar (`**`) behaves like bash's globstar
// option: a pattern such as `path/to/**.txt` would return the same results as
// `path/to/*.txt`. The pattern you're looking for is `path/to/**/*.txt`.
//
// Note: this is meant as a drop-in replacement for path.Match() which
// always uses '/' as the path separator. If you want to support systems
// which use a different path separator (such as Windows), what you want
// is PathMatch(). Alternatively, you can run filepath.ToSlash() on both
// pattern and name and then use this function.
//
// Note: users should _not_ count on the returned error,
// doublestar.ErrBadPattern, being equal to path.ErrBadPattern.
func Match(pattern, name string) (bool, error) {
return matchWithSeparator(pattern, name, '/', true)
}
func matchWithSeparator(pattern, name string, separator rune, validate bool) (matched bool, err error) {
return doMatchWithSeparator(pattern, name, separator, validate, -1, -1, -1, -1, 0, 0)
}
func doMatchWithSeparator(pattern, name string, separator rune, validate bool, doublestarPatternBacktrack, doublestarNameBacktrack, starPatternBacktrack, starNameBacktrack, patIdx, nameIdx int) (matched bool, err error) {
patLen := len(pattern)
nameLen := len(name)
startOfSegment := true
MATCH:
for nameIdx < nameLen {
if patIdx < patLen {
switch pattern[patIdx] {
case '*':
if patIdx++; patIdx < patLen && pattern[patIdx] == '*' {
// doublestar - must begin with a path separator, otherwise we'll
// treat it like a single star like bash
patIdx++
if startOfSegment {
if patIdx >= patLen {
// pattern ends in `/**`: return true
return true, nil
}
// doublestar must also end with a path separator, otherwise we're
// just going to treat the doublestar as a single star like bash
patRune, patRuneLen := utf8.DecodeRuneInString(pattern[patIdx:])
if patRune == separator {
patIdx += patRuneLen
doublestarPatternBacktrack = patIdx
doublestarNameBacktrack = nameIdx
starPatternBacktrack = -1
starNameBacktrack = -1
continue
}
}
}
startOfSegment = false
starPatternBacktrack = patIdx
starNameBacktrack = nameIdx
continue
case '?':
startOfSegment = false
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
if nameRune == separator {
// `?` cannot match the separator
break
}
patIdx++
nameIdx += nameRuneLen
continue
case '[':
startOfSegment = false
if patIdx++; patIdx >= patLen {
// class didn't end
return false, ErrBadPattern
}
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
matched := false
negate := pattern[patIdx] == '!' || pattern[patIdx] == '^'
if negate {
patIdx++
}
if patIdx >= patLen || pattern[patIdx] == ']' {
// class didn't end or empty character class
return false, ErrBadPattern
}
last := utf8.MaxRune
for patIdx < patLen && pattern[patIdx] != ']' {
patRune, patRuneLen := utf8.DecodeRuneInString(pattern[patIdx:])
patIdx += patRuneLen
// match a range
if last < utf8.MaxRune && patRune == '-' && patIdx < patLen && pattern[patIdx] != ']' {
if pattern[patIdx] == '\\' {
// next character is escaped
patIdx++
}
patRune, patRuneLen = utf8.DecodeRuneInString(pattern[patIdx:])
patIdx += patRuneLen
if last <= nameRune && nameRune <= patRune {
matched = true
break
}
// didn't match range - reset `last`
last = utf8.MaxRune
continue
}
// not a range - check if the next rune is escaped
if patRune == '\\' {
patRune, patRuneLen = utf8.DecodeRuneInString(pattern[patIdx:])
patIdx += patRuneLen
}
// check if the rune matches
if patRune == nameRune {
matched = true
break
}
// no matches yet
last = patRune
}
if matched == negate {
// failed to match - if we reached the end of the pattern, that means
// we never found a closing `]`
if patIdx >= patLen {
return false, ErrBadPattern
}
break
}
closingIdx := indexUnescapedByte(pattern[patIdx:], ']', true)
if closingIdx == -1 {
// no closing `]`
return false, ErrBadPattern
}
patIdx += closingIdx + 1
nameIdx += nameRuneLen
continue
case '{':
startOfSegment = false
beforeIdx := patIdx
patIdx++
closingIdx := indexMatchedClosingAlt(pattern[patIdx:], separator != '\\')
if closingIdx == -1 {
// no closing `}`
return false, ErrBadPattern
}
closingIdx += patIdx
for {
commaIdx := indexNextAlt(pattern[patIdx:closingIdx], separator != '\\')
if commaIdx == -1 {
break
}
commaIdx += patIdx
result, err := doMatchWithSeparator(pattern[:beforeIdx]+pattern[patIdx:commaIdx]+pattern[closingIdx+1:], name, separator, validate, doublestarPatternBacktrack, doublestarNameBacktrack, starPatternBacktrack, starNameBacktrack, beforeIdx, nameIdx)
if result || err != nil {
return result, err
}
patIdx = commaIdx + 1
}
return doMatchWithSeparator(pattern[:beforeIdx]+pattern[patIdx:closingIdx]+pattern[closingIdx+1:], name, separator, validate, doublestarPatternBacktrack, doublestarNameBacktrack, starPatternBacktrack, starNameBacktrack, beforeIdx, nameIdx)
case '\\':
if separator != '\\' {
// next rune is "escaped" in the pattern - literal match
if patIdx++; patIdx >= patLen {
// pattern ended
return false, ErrBadPattern
}
}
fallthrough
default:
patRune, patRuneLen := utf8.DecodeRuneInString(pattern[patIdx:])
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
if patRune != nameRune {
if separator != '\\' && patIdx > 0 && pattern[patIdx-1] == '\\' {
// if this rune was meant to be escaped, we need to move patIdx
// back to the backslash before backtracking or validating below
patIdx--
}
break
}
patIdx += patRuneLen
nameIdx += nameRuneLen
startOfSegment = patRune == separator
continue
}
}
if starPatternBacktrack >= 0 {
// `*` backtrack, but only if the `name` rune isn't the separator
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[starNameBacktrack:])
if nameRune != separator {
starNameBacktrack += nameRuneLen
patIdx = starPatternBacktrack
nameIdx = starNameBacktrack
startOfSegment = false
continue
}
}
if doublestarPatternBacktrack >= 0 {
// `**` backtrack, advance `name` past next separator
nameIdx = doublestarNameBacktrack
for nameIdx < nameLen {
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
nameIdx += nameRuneLen
if nameRune == separator {
doublestarNameBacktrack = nameIdx
patIdx = doublestarPatternBacktrack
startOfSegment = true
continue MATCH
}
}
}
if validate && patIdx < patLen && !doValidatePattern(pattern[patIdx:], separator) {
return false, ErrBadPattern
}
return false, nil
}
if nameIdx < nameLen {
// we reached the end of `pattern` before the end of `name`
return false, nil
}
// we've reached the end of `name`; we've successfully matched if we've also
// reached the end of `pattern`, or if the rest of `pattern` can match a
// zero-length string
return isZeroLengthPattern(pattern[patIdx:], separator)
}
func isZeroLengthPattern(pattern string, separator rune) (ret bool, err error) {
// `/**`, `**/`, and `/**/` are special cases - a pattern such as `path/to/a/**` or `path/to/a/**/`
// *should* match `path/to/a` because `a` might be a directory
if pattern == "" ||
pattern == "*" ||
pattern == "**" ||
pattern == string(separator)+"**" ||
pattern == "**"+string(separator) ||
pattern == string(separator)+"**"+string(separator) {
return true, nil
}
if pattern[0] == '{' {
closingIdx := indexMatchedClosingAlt(pattern[1:], separator != '\\')
if closingIdx == -1 {
// no closing '}'
return false, ErrBadPattern
}
closingIdx += 1
patIdx := 1
for {
commaIdx := indexNextAlt(pattern[patIdx:closingIdx], separator != '\\')
if commaIdx == -1 {
break
}
commaIdx += patIdx
ret, err = isZeroLengthPattern(pattern[patIdx:commaIdx]+pattern[closingIdx+1:], separator)
if ret || err != nil {
return
}
patIdx = commaIdx + 1
}
return isZeroLengthPattern(pattern[patIdx:closingIdx]+pattern[closingIdx+1:], separator)
}
// no luck - validate the rest of the pattern
if !doValidatePattern(pattern, separator) {
return false, ErrBadPattern
}
return false, nil
}
// Finds the next comma, but ignores any commas that appear inside nested `{}`.
// Assumes that each opening bracket has a corresponding closing bracket.
func indexNextAlt(s string, allowEscaping bool) int {
alts := 1
l := len(s)
for i := 0; i < l; i++ {
if allowEscaping && s[i] == '\\' {
// skip next byte
i++
} else if s[i] == '{' {
alts++
} else if s[i] == '}' {
alts--
} else if s[i] == ',' && alts == 1 {
return i
}
}
return -1
}
// Finds the index of the first unescaped byte `c`, or negative 1.
func indexUnescapedByte(s string, c byte, allowEscaping bool) int {
l := len(s)
for i := 0; i < l; i++ {
if allowEscaping && s[i] == '\\' {
// skip next byte
i++
} else if s[i] == c {
return i
}
}
return -1
}
// Assuming the byte before the beginning of `s` is an opening `{`, this
// function will find the index of the matching `}`. That is, it'll skip over
// any nested `{}` and account for escaping
func indexMatchedClosingAlt(s string, allowEscaping bool) int {
alts := 1
l := len(s)
for i := 0; i < l; i++ {
if allowEscaping && s[i] == '\\' {
// skip next byte
i++
} else if s[i] == '{' {
alts++
} else if s[i] == '}' {
if alts--; alts == 0 {
return i
}
}
}
return -1
}
// Validate a pattern. Patterns are validated while they run in Match(),
// PathMatch(), and Glob(), so, you normally wouldn't need to call this.
// However, there are cases where this might be useful: for example, if your
// program allows a user to enter a pattern that you'll run at a later time,
// you might want to validate it.
//
// ValidatePattern assumes your pattern uses '/' as the path separator.
func ValidatePattern(s string) bool {
return doValidatePattern(s, '/')
}
func doValidatePattern(s string, separator rune) bool {
altDepth := 0
l := len(s)
VALIDATE:
for i := 0; i < l; i++ {
switch s[i] {
case '\\':
if separator != '\\' {
// skip the next byte - return false if there is no next byte
if i++; i >= l {
return false
}
}
continue
case '[':
if i++; i >= l {
// class didn't end
return false
}
if s[i] == '^' || s[i] == '!' {
i++
}
if i >= l || s[i] == ']' {
// class didn't end or empty character class
return false
}
for ; i < l; i++ {
if separator != '\\' && s[i] == '\\' {
i++
} else if s[i] == ']' {
// looks good
continue VALIDATE
}
}
// class didn't end
return false
case '{':
altDepth++
continue
case '}':
if altDepth == 0 {
// alt end without a corresponding start
return false
}
altDepth--
continue
}
}
// valid as long as all alts are closed
return altDepth == 0
}

View File

@ -0,0 +1,122 @@
// Package protohelpers provides helper functions for encoding and decoding protobuf messages.
// The spec can be found at https://protobuf.dev/programming-guides/encoding/.
package protohelpers
import (
"fmt"
"io"
"math/bits"
)
var (
// ErrInvalidLength is returned when decoding a negative length.
ErrInvalidLength = fmt.Errorf("proto: negative length found during unmarshaling")
// ErrIntOverflow is returned when decoding a varint representation of an integer that overflows 64 bits.
ErrIntOverflow = fmt.Errorf("proto: integer overflow")
// ErrUnexpectedEndOfGroup is returned when decoding a group end without a corresponding group start.
ErrUnexpectedEndOfGroup = fmt.Errorf("proto: unexpected end of group")
)
// EncodeVarint encodes a uint64 into a varint-encoded byte slice and returns the offset of the encoded value.
// The provided offset is the offset after the last byte of the encoded value.
func EncodeVarint(dAtA []byte, offset int, v uint64) int {
offset -= SizeOfVarint(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
// SizeOfVarint returns the size of the varint-encoded value.
func SizeOfVarint(x uint64) (n int) {
return (bits.Len64(x|1) + 6) / 7
}
// SizeOfZigzag returns the size of the zigzag-encoded value.
func SizeOfZigzag(x uint64) (n int) {
return SizeOfVarint(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
// Skip the first record of the byte slice and return the offset of the next record.
func Skip(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflow
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflow
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflow
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLength
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroup
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLength
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}

View File

@ -0,0 +1,317 @@
// Code generated by protoc-gen-go-vtproto. DO NOT EDIT.
// protoc-gen-go-vtproto version: (devel)
// source: google/protobuf/timestamp.proto
package timestamppb
import (
fmt "fmt"
protohelpers "github.com/planetscale/vtprotobuf/protohelpers"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
io "io"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Timestamp timestamppb.Timestamp
func (m *Timestamp) CloneVT() *Timestamp {
if m == nil {
return (*Timestamp)(nil)
}
r := new(Timestamp)
r.Seconds = m.Seconds
r.Nanos = m.Nanos
return r
}
func (this *Timestamp) EqualVT(that *Timestamp) bool {
if this == that {
return true
} else if this == nil || that == nil {
return false
}
if this.Seconds != that.Seconds {
return false
}
if this.Nanos != that.Nanos {
return false
}
return true
}
func (m *Timestamp) MarshalVT() (dAtA []byte, err error) {
if m == nil {
return nil, nil
}
size := m.SizeVT()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBufferVT(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Timestamp) MarshalToVT(dAtA []byte) (int, error) {
size := m.SizeVT()
return m.MarshalToSizedBufferVT(dAtA[:size])
}
func (m *Timestamp) MarshalToSizedBufferVT(dAtA []byte) (int, error) {
if m == nil {
return 0, nil
}
i := len(dAtA)
_ = i
var l int
_ = l
if m.Nanos != 0 {
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Nanos))
i--
dAtA[i] = 0x10
}
if m.Seconds != 0 {
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Seconds))
i--
dAtA[i] = 0x8
}
return len(dAtA) - i, nil
}
func (m *Timestamp) MarshalVTStrict() (dAtA []byte, err error) {
if m == nil {
return nil, nil
}
size := m.SizeVT()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBufferVTStrict(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Timestamp) MarshalToVTStrict(dAtA []byte) (int, error) {
size := m.SizeVT()
return m.MarshalToSizedBufferVTStrict(dAtA[:size])
}
func (m *Timestamp) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error) {
if m == nil {
return 0, nil
}
i := len(dAtA)
_ = i
var l int
_ = l
if m.Nanos != 0 {
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Nanos))
i--
dAtA[i] = 0x10
}
if m.Seconds != 0 {
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Seconds))
i--
dAtA[i] = 0x8
}
return len(dAtA) - i, nil
}
func (m *Timestamp) SizeVT() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.Seconds != 0 {
n += 1 + protohelpers.SizeOfVarint(uint64(m.Seconds))
}
if m.Nanos != 0 {
n += 1 + protohelpers.SizeOfVarint(uint64(m.Nanos))
}
return n
}
func (m *Timestamp) UnmarshalVT(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Timestamp: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Timestamp: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Seconds", wireType)
}
m.Seconds = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Seconds |= int64(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Nanos", wireType)
}
m.Nanos = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Nanos |= int32(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := protohelpers.Skip(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return protohelpers.ErrInvalidLength
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *Timestamp) UnmarshalVTUnsafe(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Timestamp: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Timestamp: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Seconds", wireType)
}
m.Seconds = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Seconds |= int64(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Nanos", wireType)
}
m.Nanos = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Nanos |= int32(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := protohelpers.Skip(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return protohelpers.ErrInvalidLength
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}

View File

@ -0,0 +1,95 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.31.0
// protoc v3.21.12
// source: github.com/planetscale/vtprotobuf/vtproto/ext.proto
package vtproto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
descriptorpb "google.golang.org/protobuf/types/descriptorpb"
reflect "reflect"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
var file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_extTypes = []protoimpl.ExtensionInfo{
{
ExtendedType: (*descriptorpb.MessageOptions)(nil),
ExtensionType: (*bool)(nil),
Field: 64101,
Name: "vtproto.mempool",
Tag: "varint,64101,opt,name=mempool",
Filename: "github.com/planetscale/vtprotobuf/vtproto/ext.proto",
},
}
// Extension fields to descriptorpb.MessageOptions.
var (
// optional bool mempool = 64101;
E_Mempool = &file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_extTypes[0]
)
var File_github_com_planetscale_vtprotobuf_vtproto_ext_proto protoreflect.FileDescriptor
var file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_rawDesc = []byte{
0x0a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6c, 0x61,
0x6e, 0x65, 0x74, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x75, 0x66, 0x2f, 0x76, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x65, 0x78, 0x74, 0x2e,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x76, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x20,
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f,
0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x3a, 0x3b, 0x0a, 0x07, 0x6d, 0x65, 0x6d, 0x70, 0x6f, 0x6f, 0x6c, 0x12, 0x1f, 0x2e, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xe5, 0xf4, 0x03,
0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x6d, 0x65, 0x6d, 0x70, 0x6f, 0x6f, 0x6c, 0x42, 0x49, 0x0a,
0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x75, 0x66, 0x42, 0x07, 0x56, 0x54, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x5a, 0x29, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x65, 0x74,
0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
0x2f, 0x76, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f,
}
var file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_goTypes = []interface{}{
(*descriptorpb.MessageOptions)(nil), // 0: google.protobuf.MessageOptions
}
var file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_depIdxs = []int32{
0, // 0: vtproto.mempool:extendee -> google.protobuf.MessageOptions
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
0, // [0:1] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_init() }
func file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_init() {
if File_github_com_planetscale_vtprotobuf_vtproto_ext_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_rawDesc,
NumEnums: 0,
NumMessages: 0,
NumExtensions: 1,
NumServices: 0,
},
GoTypes: file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_goTypes,
DependencyIndexes: file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_depIdxs,
ExtensionInfos: file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_extTypes,
}.Build()
File_github_com_planetscale_vtprotobuf_vtproto_ext_proto = out.File
file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_rawDesc = nil
file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_goTypes = nil
file_github_com_planetscale_vtprotobuf_vtproto_ext_proto_depIdxs = nil
}