hack: generate vtproto files for buildx

Integrates vtproto into buildx. The generated files dockerfile has been
modified to copy the buildkit equivalent file to ensure files are laid
out in the appropriate way for imports.

An import has also been included to change the grpc codec to the version
in buildkit that supports vtproto. This will allow buildx to utilize the
speed and memory improvements from that.

Also updates the gc control options for prune.

Signed-off-by: Jonathan A. Sternberg <jonathan.sternberg@docker.com>
This commit is contained in:
Jonathan A. Sternberg
2024-10-08 13:35:06 -05:00
parent d353f5f6ba
commit 64c5139ab6
109 changed files with 68070 additions and 2941 deletions

View File

@ -0,0 +1,58 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"fmt"
"sort"
"google.golang.org/protobuf/compiler/protogen"
)
var defaultFeatures = make(map[string]Feature)
func findFeatures(featureNames []string) ([]Feature, error) {
required := make(map[string]Feature)
for _, name := range featureNames {
if name == "all" {
required = defaultFeatures
break
}
feat, ok := defaultFeatures[name]
if !ok {
return nil, fmt.Errorf("unknown feature: %q", name)
}
required[name] = feat
}
type namefeat struct {
name string
feat Feature
}
var sorted []namefeat
for name, feat := range required {
sorted = append(sorted, namefeat{name, feat})
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].name < sorted[j].name
})
var features []Feature
for _, sp := range sorted {
features = append(features, sp.feat)
}
return features, nil
}
func RegisterFeature(name string, feat Feature) {
defaultFeatures[name] = feat
}
type Feature func(gen *GeneratedFile) FeatureGenerator
type FeatureGenerator interface {
GenerateFile(file *protogen.File) bool
}

View File

@ -0,0 +1,198 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"fmt"
"github.com/planetscale/vtprotobuf/vtproto"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
type GeneratedFile struct {
*protogen.GeneratedFile
Config *Config
LocalPackages map[protoreflect.FullName]bool
}
func (p *GeneratedFile) Ident(path, ident string) string {
return p.QualifiedGoIdent(protogen.GoImportPath(path).Ident(ident))
}
func (b *GeneratedFile) ShouldPool(message *protogen.Message) bool {
// Do not generate pool if message is nil or message excluded by external rules
if message == nil || b.Config.PoolableExclude.Contains(message.GoIdent) {
return false
}
if b.Config.Poolable.Contains(message.GoIdent) {
return true
}
ext := proto.GetExtension(message.Desc.Options(), vtproto.E_Mempool)
if mempool, ok := ext.(bool); ok {
return mempool
}
return false
}
func (b *GeneratedFile) Alloc(vname string, message *protogen.Message, isQualifiedIdent bool) {
ident := message.GoIdent.GoName
if isQualifiedIdent {
ident = b.QualifiedGoIdent(message.GoIdent)
}
if b.ShouldPool(message) {
b.P(vname, " := ", ident, `FromVTPool()`)
} else {
b.P(vname, " := new(", ident, `)`)
}
}
func (p *GeneratedFile) FieldGoType(field *protogen.Field) (goType string, pointer bool) {
if field.Desc.IsWeak() {
return "struct{}", false
}
pointer = field.Desc.HasPresence()
switch field.Desc.Kind() {
case protoreflect.BoolKind:
goType = "bool"
case protoreflect.EnumKind:
goType = p.QualifiedGoIdent(field.Enum.GoIdent)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
goType = "int32"
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
goType = "uint32"
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
goType = "int64"
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
goType = "uint64"
case protoreflect.FloatKind:
goType = "float32"
case protoreflect.DoubleKind:
goType = "float64"
case protoreflect.StringKind:
goType = "string"
case protoreflect.BytesKind:
goType = "[]byte"
pointer = false // rely on nullability of slices for presence
case protoreflect.MessageKind, protoreflect.GroupKind:
goType = "*" + p.QualifiedGoIdent(field.Message.GoIdent)
pointer = false // pointer captured as part of the type
}
switch {
case field.Desc.IsList():
return "[]" + goType, false
case field.Desc.IsMap():
keyType, _ := p.FieldGoType(field.Message.Fields[0])
valType, _ := p.FieldGoType(field.Message.Fields[1])
return fmt.Sprintf("map[%v]%v", keyType, valType), false
}
return goType, pointer
}
func (p *GeneratedFile) IsLocalMessage(message *protogen.Message) bool {
if message == nil {
return false
}
pkg := message.Desc.ParentFile().Package()
return p.LocalPackages[pkg]
}
func (p *GeneratedFile) IsLocalField(field *protogen.Field) bool {
if field == nil {
return false
}
pkg := field.Desc.ParentFile().Package()
return p.LocalPackages[pkg]
}
const vtHelpersPackage = protogen.GoImportPath("github.com/planetscale/vtprotobuf/protohelpers")
var helpers = map[string]protogen.GoIdent{
"EncodeVarint": {GoName: "EncodeVarint", GoImportPath: vtHelpersPackage},
"SizeOfVarint": {GoName: "SizeOfVarint", GoImportPath: vtHelpersPackage},
"SizeOfZigzag": {GoName: "SizeOfZigzag", GoImportPath: vtHelpersPackage},
"Skip": {GoName: "Skip", GoImportPath: vtHelpersPackage},
"ErrInvalidLength": {GoName: "ErrInvalidLength", GoImportPath: vtHelpersPackage},
"ErrIntOverflow": {GoName: "ErrIntOverflow", GoImportPath: vtHelpersPackage},
"ErrUnexpectedEndOfGroup": {GoName: "ErrUnexpectedEndOfGroup", GoImportPath: vtHelpersPackage},
}
func (p *GeneratedFile) Helper(name string) protogen.GoIdent {
return helpers[name]
}
const vtWellKnownPackage = protogen.GoImportPath("github.com/planetscale/vtprotobuf/types/known/")
var wellKnownTypes = map[protoreflect.FullName]protogen.GoIdent{
"google.protobuf.Any": {GoName: "Any", GoImportPath: vtWellKnownPackage + "anypb"},
"google.protobuf.Duration": {GoName: "Duration", GoImportPath: vtWellKnownPackage + "durationpb"},
"google.protobuf.Empty": {GoName: "Empty", GoImportPath: vtWellKnownPackage + "emptypb"},
"google.protobuf.FieldMask": {GoName: "FieldMask", GoImportPath: vtWellKnownPackage + "fieldmaskpb"},
"google.protobuf.Timestamp": {GoName: "Timestamp", GoImportPath: vtWellKnownPackage + "timestamppb"},
"google.protobuf.DoubleValue": {GoName: "DoubleValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.FloatValue": {GoName: "FloatValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.Int64Value": {GoName: "Int64Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.UInt64Value": {GoName: "UInt64Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.Int32Value": {GoName: "Int32Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.UInt32Value": {GoName: "UInt32Value", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.BoolValue": {GoName: "BoolValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.StringValue": {GoName: "StringValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.BytesValue": {GoName: "BytesValue", GoImportPath: vtWellKnownPackage + "wrapperspb"},
"google.protobuf.Struct": {GoName: "Struct", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value": {GoName: "Value", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.ListValue": {GoName: "ListValue", GoImportPath: vtWellKnownPackage + "structpb"},
}
var wellKnownFields = map[protoreflect.FullName]protogen.GoIdent{
"google.protobuf.Value.null_value": {GoName: "Value_NullValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.number_value": {GoName: "Value_NumberValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.string_value": {GoName: "Value_StringValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.bool_value": {GoName: "Value_BoolValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.struct_value": {GoName: "Value_StructValue", GoImportPath: vtWellKnownPackage + "structpb"},
"google.protobuf.Value.list_value": {GoName: "Value_ListValue", GoImportPath: vtWellKnownPackage + "structpb"},
}
func (p *GeneratedFile) IsWellKnownType(message *protogen.Message) bool {
if message == nil {
return false
}
_, ok := wellKnownTypes[message.Desc.FullName()]
return ok
}
func (p *GeneratedFile) WellKnownFieldMap(field *protogen.Field) protogen.GoIdent {
if field == nil {
return protogen.GoIdent{}
}
res, ff := wellKnownFields[field.Desc.FullName()]
if !ff {
panic(field.Desc.FullName())
}
if p.IsLocalField(field) {
res.GoImportPath = ""
}
return res
}
func (p *GeneratedFile) WellKnownTypeMap(message *protogen.Message) protogen.GoIdent {
if message == nil {
return protogen.GoIdent{}
}
res := wellKnownTypes[message.Desc.FullName()]
if p.IsLocalMessage(message) {
res.GoImportPath = ""
}
return res
}
func (p *GeneratedFile) Wrapper() bool {
return p.Config.Wrap
}

View File

@ -0,0 +1,165 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"fmt"
"runtime/debug"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/types/pluginpb"
"github.com/planetscale/vtprotobuf/generator/pattern"
)
type ObjectSet struct {
mp map[string]bool
}
func NewObjectSet() ObjectSet {
return ObjectSet{
mp: map[string]bool{},
}
}
func (o ObjectSet) String() string {
return fmt.Sprintf("%#v", o)
}
func (o ObjectSet) Contains(g protogen.GoIdent) bool {
objectPath := fmt.Sprintf("%s.%s", string(g.GoImportPath), g.GoName)
for wildcard := range o.mp {
// Ignore malformed pattern error because pattern already checked in Set
if ok, _ := pattern.Match(wildcard, objectPath); ok {
return true
}
}
return false
}
func (o ObjectSet) Set(s string) error {
if !pattern.ValidatePattern(s) {
return pattern.ErrBadPattern
}
o.mp[s] = true
return nil
}
type Config struct {
// Poolable rules determines if pool feature generate for particular message
Poolable ObjectSet
// PoolableExclude rules determines if pool feature disabled for particular message
PoolableExclude ObjectSet
Wrap bool
WellKnownTypes bool
AllowEmpty bool
BuildTag string
}
type Generator struct {
plugin *protogen.Plugin
cfg *Config
features []Feature
local map[protoreflect.FullName]bool
}
const SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
func NewGenerator(plugin *protogen.Plugin, featureNames []string, cfg *Config) (*Generator, error) {
plugin.SupportedFeatures = SupportedFeatures
features, err := findFeatures(featureNames)
if err != nil {
return nil, err
}
local := make(map[protoreflect.FullName]bool)
for _, f := range plugin.Files {
if f.Generate {
local[f.Desc.Package()] = true
}
}
return &Generator{
plugin: plugin,
cfg: cfg,
features: features,
local: local,
}, nil
}
func (gen *Generator) Generate() {
for _, file := range gen.plugin.Files {
if !file.Generate {
continue
}
var importPath protogen.GoImportPath
if !gen.cfg.Wrap {
importPath = file.GoImportPath
}
gf := gen.plugin.NewGeneratedFile(file.GeneratedFilenamePrefix+"_vtproto.pb.go", importPath)
gen.generateFile(gf, file)
}
}
func (gen *Generator) generateFile(gf *protogen.GeneratedFile, file *protogen.File) {
p := &GeneratedFile{
GeneratedFile: gf,
Config: gen.cfg,
LocalPackages: gen.local,
}
if p.Config.BuildTag != "" {
// Support both forms of tags for maximum compatibility
p.P("//go:build ", p.Config.BuildTag)
p.P("// +build ", p.Config.BuildTag)
}
p.P("// Code generated by protoc-gen-go-vtproto. DO NOT EDIT.")
if bi, ok := debug.ReadBuildInfo(); ok {
p.P("// protoc-gen-go-vtproto version: ", bi.Main.Version)
}
p.P("// source: ", file.Desc.Path())
p.P()
p.P("package ", file.GoPackageName)
p.P()
protoimplPackage := protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl")
p.P("const (")
p.P("// Verify that this generated code is sufficiently up-to-date.")
p.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")")
p.P("// Verify that runtime/protoimpl is sufficiently up-to-date.")
p.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")")
p.P(")")
p.P()
if p.Wrapper() {
for _, msg := range file.Messages {
p.P(`type `, msg.GoIdent.GoName, ` `, msg.GoIdent)
for _, one := range msg.Oneofs {
for _, field := range one.Fields {
p.P(`type `, field.GoIdent.GoName, ` `, field.GoIdent)
}
}
}
}
var generated bool
for _, feat := range gen.features {
featGenerator := feat(p)
if featGenerator.GenerateFile(file) {
generated = true
}
}
if !generated && !gen.cfg.AllowEmpty {
gf.Skip()
}
}

View File

@ -0,0 +1,47 @@
// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
)
const ProtoPkg = "google.golang.org/protobuf/proto"
func KeySize(fieldNumber protoreflect.FieldNumber, wireType protowire.Type) int {
x := uint32(fieldNumber)<<3 | uint32(wireType)
size := 0
for size = 0; x > 127; size++ {
x >>= 7
}
size++
return size
}
var wireTypes = map[protoreflect.Kind]protowire.Type{
protoreflect.BoolKind: protowire.VarintType,
protoreflect.EnumKind: protowire.VarintType,
protoreflect.Int32Kind: protowire.VarintType,
protoreflect.Sint32Kind: protowire.VarintType,
protoreflect.Uint32Kind: protowire.VarintType,
protoreflect.Int64Kind: protowire.VarintType,
protoreflect.Sint64Kind: protowire.VarintType,
protoreflect.Uint64Kind: protowire.VarintType,
protoreflect.Sfixed32Kind: protowire.Fixed32Type,
protoreflect.Fixed32Kind: protowire.Fixed32Type,
protoreflect.FloatKind: protowire.Fixed32Type,
protoreflect.Sfixed64Kind: protowire.Fixed64Type,
protoreflect.Fixed64Kind: protowire.Fixed64Type,
protoreflect.DoubleKind: protowire.Fixed64Type,
protoreflect.StringKind: protowire.BytesType,
protoreflect.BytesKind: protowire.BytesType,
protoreflect.MessageKind: protowire.BytesType,
protoreflect.GroupKind: protowire.StartGroupType,
}
func ProtoWireType(k protoreflect.Kind) protowire.Type {
return wireTypes[k]
}

View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Bob Matcuk
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,457 @@
package pattern
import (
"path"
"unicode/utf8"
)
var ErrBadPattern = path.ErrBadPattern
// Match reports whether name matches the shell pattern.
// The pattern syntax is:
//
// pattern:
// { term }
// term:
// '*' matches any sequence of non-path-separators
// '/**/' matches zero or more directories
// '?' matches any single non-path-separator character
// '[' [ '^' '!' ] { character-range } ']'
// character class (must be non-empty)
// starting with `^` or `!` negates the class
// '{' { term } [ ',' { term } ... ] '}'
// alternatives
// c matches character c (c != '*', '?', '\\', '[')
// '\\' c matches character c
//
// character-range:
// c matches character c (c != '\\', '-', ']')
// '\\' c matches character c
// lo '-' hi matches character c for lo <= c <= hi
//
// Match returns true if `name` matches the file name `pattern`. `name` and
// `pattern` are split on forward slash (`/`) characters and may be relative or
// absolute.
//
// Match requires pattern to match all of name, not just a substring.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
//
// A doublestar (`**`) should appear surrounded by path separators such as
// `/**/`. A mid-pattern doublestar (`**`) behaves like bash's globstar
// option: a pattern such as `path/to/**.txt` would return the same results as
// `path/to/*.txt`. The pattern you're looking for is `path/to/**/*.txt`.
//
// Note: this is meant as a drop-in replacement for path.Match() which
// always uses '/' as the path separator. If you want to support systems
// which use a different path separator (such as Windows), what you want
// is PathMatch(). Alternatively, you can run filepath.ToSlash() on both
// pattern and name and then use this function.
//
// Note: users should _not_ count on the returned error,
// doublestar.ErrBadPattern, being equal to path.ErrBadPattern.
func Match(pattern, name string) (bool, error) {
return matchWithSeparator(pattern, name, '/', true)
}
func matchWithSeparator(pattern, name string, separator rune, validate bool) (matched bool, err error) {
return doMatchWithSeparator(pattern, name, separator, validate, -1, -1, -1, -1, 0, 0)
}
func doMatchWithSeparator(pattern, name string, separator rune, validate bool, doublestarPatternBacktrack, doublestarNameBacktrack, starPatternBacktrack, starNameBacktrack, patIdx, nameIdx int) (matched bool, err error) {
patLen := len(pattern)
nameLen := len(name)
startOfSegment := true
MATCH:
for nameIdx < nameLen {
if patIdx < patLen {
switch pattern[patIdx] {
case '*':
if patIdx++; patIdx < patLen && pattern[patIdx] == '*' {
// doublestar - must begin with a path separator, otherwise we'll
// treat it like a single star like bash
patIdx++
if startOfSegment {
if patIdx >= patLen {
// pattern ends in `/**`: return true
return true, nil
}
// doublestar must also end with a path separator, otherwise we're
// just going to treat the doublestar as a single star like bash
patRune, patRuneLen := utf8.DecodeRuneInString(pattern[patIdx:])
if patRune == separator {
patIdx += patRuneLen
doublestarPatternBacktrack = patIdx
doublestarNameBacktrack = nameIdx
starPatternBacktrack = -1
starNameBacktrack = -1
continue
}
}
}
startOfSegment = false
starPatternBacktrack = patIdx
starNameBacktrack = nameIdx
continue
case '?':
startOfSegment = false
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
if nameRune == separator {
// `?` cannot match the separator
break
}
patIdx++
nameIdx += nameRuneLen
continue
case '[':
startOfSegment = false
if patIdx++; patIdx >= patLen {
// class didn't end
return false, ErrBadPattern
}
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
matched := false
negate := pattern[patIdx] == '!' || pattern[patIdx] == '^'
if negate {
patIdx++
}
if patIdx >= patLen || pattern[patIdx] == ']' {
// class didn't end or empty character class
return false, ErrBadPattern
}
last := utf8.MaxRune
for patIdx < patLen && pattern[patIdx] != ']' {
patRune, patRuneLen := utf8.DecodeRuneInString(pattern[patIdx:])
patIdx += patRuneLen
// match a range
if last < utf8.MaxRune && patRune == '-' && patIdx < patLen && pattern[patIdx] != ']' {
if pattern[patIdx] == '\\' {
// next character is escaped
patIdx++
}
patRune, patRuneLen = utf8.DecodeRuneInString(pattern[patIdx:])
patIdx += patRuneLen
if last <= nameRune && nameRune <= patRune {
matched = true
break
}
// didn't match range - reset `last`
last = utf8.MaxRune
continue
}
// not a range - check if the next rune is escaped
if patRune == '\\' {
patRune, patRuneLen = utf8.DecodeRuneInString(pattern[patIdx:])
patIdx += patRuneLen
}
// check if the rune matches
if patRune == nameRune {
matched = true
break
}
// no matches yet
last = patRune
}
if matched == negate {
// failed to match - if we reached the end of the pattern, that means
// we never found a closing `]`
if patIdx >= patLen {
return false, ErrBadPattern
}
break
}
closingIdx := indexUnescapedByte(pattern[patIdx:], ']', true)
if closingIdx == -1 {
// no closing `]`
return false, ErrBadPattern
}
patIdx += closingIdx + 1
nameIdx += nameRuneLen
continue
case '{':
startOfSegment = false
beforeIdx := patIdx
patIdx++
closingIdx := indexMatchedClosingAlt(pattern[patIdx:], separator != '\\')
if closingIdx == -1 {
// no closing `}`
return false, ErrBadPattern
}
closingIdx += patIdx
for {
commaIdx := indexNextAlt(pattern[patIdx:closingIdx], separator != '\\')
if commaIdx == -1 {
break
}
commaIdx += patIdx
result, err := doMatchWithSeparator(pattern[:beforeIdx]+pattern[patIdx:commaIdx]+pattern[closingIdx+1:], name, separator, validate, doublestarPatternBacktrack, doublestarNameBacktrack, starPatternBacktrack, starNameBacktrack, beforeIdx, nameIdx)
if result || err != nil {
return result, err
}
patIdx = commaIdx + 1
}
return doMatchWithSeparator(pattern[:beforeIdx]+pattern[patIdx:closingIdx]+pattern[closingIdx+1:], name, separator, validate, doublestarPatternBacktrack, doublestarNameBacktrack, starPatternBacktrack, starNameBacktrack, beforeIdx, nameIdx)
case '\\':
if separator != '\\' {
// next rune is "escaped" in the pattern - literal match
if patIdx++; patIdx >= patLen {
// pattern ended
return false, ErrBadPattern
}
}
fallthrough
default:
patRune, patRuneLen := utf8.DecodeRuneInString(pattern[patIdx:])
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
if patRune != nameRune {
if separator != '\\' && patIdx > 0 && pattern[patIdx-1] == '\\' {
// if this rune was meant to be escaped, we need to move patIdx
// back to the backslash before backtracking or validating below
patIdx--
}
break
}
patIdx += patRuneLen
nameIdx += nameRuneLen
startOfSegment = patRune == separator
continue
}
}
if starPatternBacktrack >= 0 {
// `*` backtrack, but only if the `name` rune isn't the separator
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[starNameBacktrack:])
if nameRune != separator {
starNameBacktrack += nameRuneLen
patIdx = starPatternBacktrack
nameIdx = starNameBacktrack
startOfSegment = false
continue
}
}
if doublestarPatternBacktrack >= 0 {
// `**` backtrack, advance `name` past next separator
nameIdx = doublestarNameBacktrack
for nameIdx < nameLen {
nameRune, nameRuneLen := utf8.DecodeRuneInString(name[nameIdx:])
nameIdx += nameRuneLen
if nameRune == separator {
doublestarNameBacktrack = nameIdx
patIdx = doublestarPatternBacktrack
startOfSegment = true
continue MATCH
}
}
}
if validate && patIdx < patLen && !doValidatePattern(pattern[patIdx:], separator) {
return false, ErrBadPattern
}
return false, nil
}
if nameIdx < nameLen {
// we reached the end of `pattern` before the end of `name`
return false, nil
}
// we've reached the end of `name`; we've successfully matched if we've also
// reached the end of `pattern`, or if the rest of `pattern` can match a
// zero-length string
return isZeroLengthPattern(pattern[patIdx:], separator)
}
func isZeroLengthPattern(pattern string, separator rune) (ret bool, err error) {
// `/**`, `**/`, and `/**/` are special cases - a pattern such as `path/to/a/**` or `path/to/a/**/`
// *should* match `path/to/a` because `a` might be a directory
if pattern == "" ||
pattern == "*" ||
pattern == "**" ||
pattern == string(separator)+"**" ||
pattern == "**"+string(separator) ||
pattern == string(separator)+"**"+string(separator) {
return true, nil
}
if pattern[0] == '{' {
closingIdx := indexMatchedClosingAlt(pattern[1:], separator != '\\')
if closingIdx == -1 {
// no closing '}'
return false, ErrBadPattern
}
closingIdx += 1
patIdx := 1
for {
commaIdx := indexNextAlt(pattern[patIdx:closingIdx], separator != '\\')
if commaIdx == -1 {
break
}
commaIdx += patIdx
ret, err = isZeroLengthPattern(pattern[patIdx:commaIdx]+pattern[closingIdx+1:], separator)
if ret || err != nil {
return
}
patIdx = commaIdx + 1
}
return isZeroLengthPattern(pattern[patIdx:closingIdx]+pattern[closingIdx+1:], separator)
}
// no luck - validate the rest of the pattern
if !doValidatePattern(pattern, separator) {
return false, ErrBadPattern
}
return false, nil
}
// Finds the next comma, but ignores any commas that appear inside nested `{}`.
// Assumes that each opening bracket has a corresponding closing bracket.
func indexNextAlt(s string, allowEscaping bool) int {
alts := 1
l := len(s)
for i := 0; i < l; i++ {
if allowEscaping && s[i] == '\\' {
// skip next byte
i++
} else if s[i] == '{' {
alts++
} else if s[i] == '}' {
alts--
} else if s[i] == ',' && alts == 1 {
return i
}
}
return -1
}
// Finds the index of the first unescaped byte `c`, or negative 1.
func indexUnescapedByte(s string, c byte, allowEscaping bool) int {
l := len(s)
for i := 0; i < l; i++ {
if allowEscaping && s[i] == '\\' {
// skip next byte
i++
} else if s[i] == c {
return i
}
}
return -1
}
// Assuming the byte before the beginning of `s` is an opening `{`, this
// function will find the index of the matching `}`. That is, it'll skip over
// any nested `{}` and account for escaping
func indexMatchedClosingAlt(s string, allowEscaping bool) int {
alts := 1
l := len(s)
for i := 0; i < l; i++ {
if allowEscaping && s[i] == '\\' {
// skip next byte
i++
} else if s[i] == '{' {
alts++
} else if s[i] == '}' {
if alts--; alts == 0 {
return i
}
}
}
return -1
}
// Validate a pattern. Patterns are validated while they run in Match(),
// PathMatch(), and Glob(), so, you normally wouldn't need to call this.
// However, there are cases where this might be useful: for example, if your
// program allows a user to enter a pattern that you'll run at a later time,
// you might want to validate it.
//
// ValidatePattern assumes your pattern uses '/' as the path separator.
func ValidatePattern(s string) bool {
return doValidatePattern(s, '/')
}
func doValidatePattern(s string, separator rune) bool {
altDepth := 0
l := len(s)
VALIDATE:
for i := 0; i < l; i++ {
switch s[i] {
case '\\':
if separator != '\\' {
// skip the next byte - return false if there is no next byte
if i++; i >= l {
return false
}
}
continue
case '[':
if i++; i >= l {
// class didn't end
return false
}
if s[i] == '^' || s[i] == '!' {
i++
}
if i >= l || s[i] == ']' {
// class didn't end or empty character class
return false
}
for ; i < l; i++ {
if separator != '\\' && s[i] == '\\' {
i++
} else if s[i] == ']' {
// looks good
continue VALIDATE
}
}
// class didn't end
return false
case '{':
altDepth++
continue
case '}':
if altDepth == 0 {
// alt end without a corresponding start
return false
}
altDepth--
continue
}
}
// valid as long as all alts are closed
return altDepth == 0
}