mirror of
				https://gitea.com/Lydanne/buildx.git
				synced 2025-10-31 16:13:45 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			472 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			472 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /*
 | |
|    Copyright The containerd Authors.
 | |
| 
 | |
|    Licensed under the Apache License, Version 2.0 (the "License");
 | |
|    you may not use this file except in compliance with the License.
 | |
|    You may obtain a copy of the License at
 | |
| 
 | |
|        http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
|    Unless required by applicable law or agreed to in writing, software
 | |
|    distributed under the License is distributed on an "AS IS" BASIS,
 | |
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|    See the License for the specific language governing permissions and
 | |
|    limitations under the License.
 | |
| */
 | |
| 
 | |
| package ttrpc
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"io"
 | |
| 	"math/rand"
 | |
| 	"net"
 | |
| 	"sync"
 | |
| 	"sync/atomic"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/pkg/errors"
 | |
| 	"github.com/sirupsen/logrus"
 | |
| 	"google.golang.org/grpc/codes"
 | |
| 	"google.golang.org/grpc/status"
 | |
| )
 | |
| 
 | |
// ErrServerClosed is returned by Serve once the server has been shut down
// or closed.
var ErrServerClosed = errors.New("ttrpc: server closed")
 | |
| 
 | |
| type Server struct {
 | |
| 	config   *serverConfig
 | |
| 	services *serviceSet
 | |
| 	codec    codec
 | |
| 
 | |
| 	mu          sync.Mutex
 | |
| 	listeners   map[net.Listener]struct{}
 | |
| 	connections map[*serverConn]struct{} // all connections to current state
 | |
| 	done        chan struct{}            // marks point at which we stop serving requests
 | |
| }
 | |
| 
 | |
| func NewServer(opts ...ServerOpt) (*Server, error) {
 | |
| 	config := &serverConfig{}
 | |
| 	for _, opt := range opts {
 | |
| 		if err := opt(config); err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return &Server{
 | |
| 		config:      config,
 | |
| 		services:    newServiceSet(),
 | |
| 		done:        make(chan struct{}),
 | |
| 		listeners:   make(map[net.Listener]struct{}),
 | |
| 		connections: make(map[*serverConn]struct{}),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (s *Server) Register(name string, methods map[string]Method) {
 | |
| 	s.services.register(name, methods)
 | |
| }
 | |
| 
 | |
| func (s *Server) Serve(ctx context.Context, l net.Listener) error {
 | |
| 	s.addListener(l)
 | |
| 	defer s.closeListener(l)
 | |
| 
 | |
| 	var (
 | |
| 		backoff    time.Duration
 | |
| 		handshaker = s.config.handshaker
 | |
| 	)
 | |
| 
 | |
| 	if handshaker == nil {
 | |
| 		handshaker = handshakerFunc(noopHandshake)
 | |
| 	}
 | |
| 
 | |
| 	for {
 | |
| 		conn, err := l.Accept()
 | |
| 		if err != nil {
 | |
| 			select {
 | |
| 			case <-s.done:
 | |
| 				return ErrServerClosed
 | |
| 			default:
 | |
| 			}
 | |
| 
 | |
| 			if terr, ok := err.(interface {
 | |
| 				Temporary() bool
 | |
| 			}); ok && terr.Temporary() {
 | |
| 				if backoff == 0 {
 | |
| 					backoff = time.Millisecond
 | |
| 				} else {
 | |
| 					backoff *= 2
 | |
| 				}
 | |
| 
 | |
| 				if max := time.Second; backoff > max {
 | |
| 					backoff = max
 | |
| 				}
 | |
| 
 | |
| 				sleep := time.Duration(rand.Int63n(int64(backoff)))
 | |
| 				logrus.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
 | |
| 				time.Sleep(sleep)
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		backoff = 0
 | |
| 
 | |
| 		approved, handshake, err := handshaker.Handshake(ctx, conn)
 | |
| 		if err != nil {
 | |
| 			logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake")
 | |
| 			conn.Close()
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		sc := s.newConn(approved, handshake)
 | |
| 		go sc.run(ctx)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (s *Server) Shutdown(ctx context.Context) error {
 | |
| 	s.mu.Lock()
 | |
| 	select {
 | |
| 	case <-s.done:
 | |
| 	default:
 | |
| 		// protected by mutex
 | |
| 		close(s.done)
 | |
| 	}
 | |
| 	lnerr := s.closeListeners()
 | |
| 	s.mu.Unlock()
 | |
| 
 | |
| 	ticker := time.NewTicker(200 * time.Millisecond)
 | |
| 	defer ticker.Stop()
 | |
| 	for {
 | |
| 		if s.closeIdleConns() {
 | |
| 			return lnerr
 | |
| 		}
 | |
| 		select {
 | |
| 		case <-ctx.Done():
 | |
| 			return ctx.Err()
 | |
| 		case <-ticker.C:
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Close the server without waiting for active connections.
 | |
| func (s *Server) Close() error {
 | |
| 	s.mu.Lock()
 | |
| 	defer s.mu.Unlock()
 | |
| 
 | |
| 	select {
 | |
| 	case <-s.done:
 | |
| 	default:
 | |
| 		// protected by mutex
 | |
| 		close(s.done)
 | |
| 	}
 | |
| 
 | |
| 	err := s.closeListeners()
 | |
| 	for c := range s.connections {
 | |
| 		c.close()
 | |
| 		delete(s.connections, c)
 | |
| 	}
 | |
| 
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (s *Server) addListener(l net.Listener) {
 | |
| 	s.mu.Lock()
 | |
| 	defer s.mu.Unlock()
 | |
| 	s.listeners[l] = struct{}{}
 | |
| }
 | |
| 
 | |
| func (s *Server) closeListener(l net.Listener) error {
 | |
| 	s.mu.Lock()
 | |
| 	defer s.mu.Unlock()
 | |
| 
 | |
| 	return s.closeListenerLocked(l)
 | |
| }
 | |
| 
 | |
| func (s *Server) closeListenerLocked(l net.Listener) error {
 | |
| 	defer delete(s.listeners, l)
 | |
| 	return l.Close()
 | |
| }
 | |
| 
 | |
| func (s *Server) closeListeners() error {
 | |
| 	var err error
 | |
| 	for l := range s.listeners {
 | |
| 		if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
 | |
| 			err = cerr
 | |
| 		}
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (s *Server) addConnection(c *serverConn) {
 | |
| 	s.mu.Lock()
 | |
| 	defer s.mu.Unlock()
 | |
| 
 | |
| 	s.connections[c] = struct{}{}
 | |
| }
 | |
| 
 | |
| func (s *Server) closeIdleConns() bool {
 | |
| 	s.mu.Lock()
 | |
| 	defer s.mu.Unlock()
 | |
| 	quiescent := true
 | |
| 	for c := range s.connections {
 | |
| 		st, ok := c.getState()
 | |
| 		if !ok || st != connStateIdle {
 | |
| 			quiescent = false
 | |
| 			continue
 | |
| 		}
 | |
| 		c.close()
 | |
| 		delete(s.connections, c)
 | |
| 	}
 | |
| 	return quiescent
 | |
| }
 | |
| 
 | |
// connState describes what a server connection is currently doing. It is
// stored in serverConn.state and read by Server.closeIdleConns to decide
// which connections can be closed.
type connState int

// Connection states. They are declared with type connState (rather than as
// untyped integer constants) so that connState's String method applies to
// them, e.g. when they are formatted with %v.
const (
	connStateActive connState = iota + 1 // outstanding requests
	connStateIdle                        // no requests
	connStateClosed                      // closed connection
)

// String returns a lower-case name for cs, or "unknown" for a value that is
// not one of the declared states.
func (cs connState) String() string {
	switch cs {
	case connStateActive:
		return "active"
	case connStateIdle:
		return "idle"
	case connStateClosed:
		return "closed"
	default:
		return "unknown"
	}
}
 | |
| 
 | |
| func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
 | |
| 	c := &serverConn{
 | |
| 		server:    s,
 | |
| 		conn:      conn,
 | |
| 		handshake: handshake,
 | |
| 		shutdown:  make(chan struct{}),
 | |
| 	}
 | |
| 	c.setState(connStateIdle)
 | |
| 	s.addConnection(c)
 | |
| 	return c
 | |
| }
 | |
| 
 | |
| type serverConn struct {
 | |
| 	server    *Server
 | |
| 	conn      net.Conn
 | |
| 	handshake interface{} // data from handshake, not used for now
 | |
| 	state     atomic.Value
 | |
| 
 | |
| 	shutdownOnce sync.Once
 | |
| 	shutdown     chan struct{} // forced shutdown, used by close
 | |
| }
 | |
| 
 | |
| func (c *serverConn) getState() (connState, bool) {
 | |
| 	cs, ok := c.state.Load().(connState)
 | |
| 	return cs, ok
 | |
| }
 | |
| 
 | |
| func (c *serverConn) setState(newstate connState) {
 | |
| 	c.state.Store(newstate)
 | |
| }
 | |
| 
 | |
| func (c *serverConn) close() error {
 | |
| 	c.shutdownOnce.Do(func() {
 | |
| 		close(c.shutdown)
 | |
| 	})
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *serverConn) run(sctx context.Context) {
 | |
| 	type (
 | |
| 		request struct {
 | |
| 			id  uint32
 | |
| 			req *Request
 | |
| 		}
 | |
| 
 | |
| 		response struct {
 | |
| 			id   uint32
 | |
| 			resp *Response
 | |
| 		}
 | |
| 	)
 | |
| 
 | |
| 	var (
 | |
| 		ch          = newChannel(c.conn)
 | |
| 		ctx, cancel = context.WithCancel(sctx)
 | |
| 		active      int
 | |
| 		state       connState = connStateIdle
 | |
| 		responses             = make(chan response)
 | |
| 		requests              = make(chan request)
 | |
| 		recvErr               = make(chan error, 1)
 | |
| 		shutdown              = c.shutdown
 | |
| 		done                  = make(chan struct{})
 | |
| 	)
 | |
| 
 | |
| 	defer c.conn.Close()
 | |
| 	defer cancel()
 | |
| 	defer close(done)
 | |
| 
 | |
| 	go func(recvErr chan error) {
 | |
| 		defer close(recvErr)
 | |
| 		sendImmediate := func(id uint32, st *status.Status) bool {
 | |
| 			select {
 | |
| 			case responses <- response{
 | |
| 				// even though we've had an invalid stream id, we send it
 | |
| 				// back on the same stream id so the client knows which
 | |
| 				// stream id was bad.
 | |
| 				id: id,
 | |
| 				resp: &Response{
 | |
| 					Status: st.Proto(),
 | |
| 				},
 | |
| 			}:
 | |
| 				return true
 | |
| 			case <-c.shutdown:
 | |
| 				return false
 | |
| 			case <-done:
 | |
| 				return false
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		for {
 | |
| 			select {
 | |
| 			case <-c.shutdown:
 | |
| 				return
 | |
| 			case <-done:
 | |
| 				return
 | |
| 			default: // proceed
 | |
| 			}
 | |
| 
 | |
| 			mh, p, err := ch.recv(ctx)
 | |
| 			if err != nil {
 | |
| 				status, ok := status.FromError(err)
 | |
| 				if !ok {
 | |
| 					recvErr <- err
 | |
| 					return
 | |
| 				}
 | |
| 
 | |
| 				// in this case, we send an error for that particular message
 | |
| 				// when the status is defined.
 | |
| 				if !sendImmediate(mh.StreamID, status) {
 | |
| 					return
 | |
| 				}
 | |
| 
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if mh.Type != messageTypeRequest {
 | |
| 				// we must ignore this for future compat.
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			var req Request
 | |
| 			if err := c.server.codec.Unmarshal(p, &req); err != nil {
 | |
| 				ch.putmbuf(p)
 | |
| 				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
 | |
| 					return
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 			ch.putmbuf(p)
 | |
| 
 | |
| 			if mh.StreamID%2 != 1 {
 | |
| 				// enforce odd client initiated identifiers.
 | |
| 				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
 | |
| 					return
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			// Forward the request to the main loop. We don't wait on s.done
 | |
| 			// because we have already accepted the client request.
 | |
| 			select {
 | |
| 			case requests <- request{
 | |
| 				id:  mh.StreamID,
 | |
| 				req: &req,
 | |
| 			}:
 | |
| 			case <-done:
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}(recvErr)
 | |
| 
 | |
| 	for {
 | |
| 		newstate := state
 | |
| 		switch {
 | |
| 		case active > 0:
 | |
| 			newstate = connStateActive
 | |
| 			shutdown = nil
 | |
| 		case active == 0:
 | |
| 			newstate = connStateIdle
 | |
| 			shutdown = c.shutdown // only enable this branch in idle mode
 | |
| 		}
 | |
| 
 | |
| 		if newstate != state {
 | |
| 			c.setState(newstate)
 | |
| 			state = newstate
 | |
| 		}
 | |
| 
 | |
| 		select {
 | |
| 		case request := <-requests:
 | |
| 			active++
 | |
| 			go func(id uint32) {
 | |
| 				ctx, cancel := getRequestContext(ctx, request.req)
 | |
| 				defer cancel()
 | |
| 
 | |
| 				p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
 | |
| 				resp := &Response{
 | |
| 					Status:  status.Proto(),
 | |
| 					Payload: p,
 | |
| 				}
 | |
| 
 | |
| 				select {
 | |
| 				case responses <- response{
 | |
| 					id:   id,
 | |
| 					resp: resp,
 | |
| 				}:
 | |
| 				case <-done:
 | |
| 				}
 | |
| 			}(request.id)
 | |
| 		case response := <-responses:
 | |
| 			p, err := c.server.codec.Marshal(response.resp)
 | |
| 			if err != nil {
 | |
| 				logrus.WithError(err).Error("failed marshaling response")
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
 | |
| 				logrus.WithError(err).Error("failed sending message on channel")
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 			active--
 | |
| 		case err := <-recvErr:
 | |
| 			// TODO(stevvooe): Not wildly clear what we should do in this
 | |
| 			// branch. Basically, it means that we are no longer receiving
 | |
| 			// requests due to a terminal error.
 | |
| 			recvErr = nil // connection is now "closing"
 | |
| 			if err != nil && err != io.EOF {
 | |
| 				logrus.WithError(err).Error("error receiving message")
 | |
| 			}
 | |
| 		case <-shutdown:
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
// noopFunc is a cancel function that does nothing; getRequestContext returns
// it when a request carries no timeout.
var noopFunc func() = func() {}
 | |
| 
 | |
| func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
 | |
| 	cancel = noopFunc
 | |
| 	if req.TimeoutNano == 0 {
 | |
| 		return ctx, cancel
 | |
| 	}
 | |
| 
 | |
| 	ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutNano))
 | |
| 	return ctx, cancel
 | |
| }
 | 
