// Copyright (c) 2021 PlanetScale Inc. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package generator import ( "fmt" "runtime/debug" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/runtime/protoimpl" "google.golang.org/protobuf/types/pluginpb" "github.com/planetscale/vtprotobuf/generator/pattern" ) type ObjectSet struct { mp map[string]bool } func NewObjectSet() ObjectSet { return ObjectSet{ mp: map[string]bool{}, } } func (o ObjectSet) String() string { return fmt.Sprintf("%#v", o) } func (o ObjectSet) Contains(g protogen.GoIdent) bool { objectPath := fmt.Sprintf("%s.%s", string(g.GoImportPath), g.GoName) for wildcard := range o.mp { // Ignore malformed pattern error because pattern already checked in Set if ok, _ := pattern.Match(wildcard, objectPath); ok { return true } } return false } func (o ObjectSet) Set(s string) error { if !pattern.ValidatePattern(s) { return pattern.ErrBadPattern } o.mp[s] = true return nil } type Config struct { // Poolable rules determines if pool feature generate for particular message Poolable ObjectSet // PoolableExclude rules determines if pool feature disabled for particular message PoolableExclude ObjectSet Wrap bool WellKnownTypes bool AllowEmpty bool BuildTag string } type Generator struct { plugin *protogen.Plugin cfg *Config features []Feature local map[protoreflect.FullName]bool } const SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) func NewGenerator(plugin *protogen.Plugin, featureNames []string, cfg *Config) (*Generator, error) { plugin.SupportedFeatures = SupportedFeatures features, err := findFeatures(featureNames) if err != nil { return nil, err } local := make(map[protoreflect.FullName]bool) for _, f := range plugin.Files { if f.Generate { local[f.Desc.Package()] = true } } return &Generator{ plugin: plugin, cfg: cfg, features: features, local: local, }, nil } func (gen *Generator) Generate() { for _, file := range gen.plugin.Files { if !file.Generate { continue } var importPath protogen.GoImportPath if !gen.cfg.Wrap { importPath = file.GoImportPath } gf := gen.plugin.NewGeneratedFile(file.GeneratedFilenamePrefix+"_vtproto.pb.go", importPath) gen.generateFile(gf, file) } } func (gen *Generator) generateFile(gf *protogen.GeneratedFile, file *protogen.File) { p := &GeneratedFile{ GeneratedFile: gf, Config: gen.cfg, LocalPackages: gen.local, } if p.Config.BuildTag != "" { // Support both forms of tags for maximum compatibility p.P("//go:build ", p.Config.BuildTag) p.P("// +build ", p.Config.BuildTag) } p.P("// Code generated by protoc-gen-go-vtproto. DO NOT EDIT.") if bi, ok := debug.ReadBuildInfo(); ok { p.P("// protoc-gen-go-vtproto version: ", bi.Main.Version) } p.P("// source: ", file.Desc.Path()) p.P() p.P("package ", file.GoPackageName) p.P() protoimplPackage := protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl") p.P("const (") p.P("// Verify that this generated code is sufficiently up-to-date.") p.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")") p.P("// Verify that runtime/protoimpl is sufficiently up-to-date.") p.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")") p.P(")") p.P() if p.Wrapper() { for _, msg := range file.Messages { p.P(`type `, msg.GoIdent.GoName, ` `, msg.GoIdent) for _, one := range msg.Oneofs { for _, field := range one.Fields { p.P(`type `, field.GoIdent.GoName, ` `, field.GoIdent) } } } } var generated bool for _, feat := range gen.features { featGenerator := feat(p) if featGenerator.GenerateFile(file) { generated = true } } if !generated && !gen.cfg.AllowEmpty { gf.Skip() } }