Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 53 additions & 3 deletions go/sqltypes/proto3.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,42 @@ func RowsToProto3(rows [][]Value) []*querypb.Row {
return nil
}

// Batch-allocate all Row structs in a single backing array to reduce
// per-row heap allocations from N to 1.
backing := make([]querypb.Row, len(rows))
result := make([]*querypb.Row, len(rows))

// Pre-allocate a single Lengths backing array for all rows.
nCols := len(rows[0])
allLengths := make([]int64, 0, nCols*len(rows))

// First pass: compute lengths and accumulate total value size.
totalValueBytes := 0
for i, r := range rows {
result[i] = RowToProto3(r)
result[i] = &backing[i]
start := len(allLengths)
for _, c := range r {
if c.IsNull() {
allLengths = append(allLengths, -1)
} else {
l := c.Len()
allLengths = append(allLengths, int64(l))
totalValueBytes += l
}
}
backing[i].Lengths = allLengths[start:]
}

// Second pass: batch-allocate all Values into a single buffer.
allValues := make([]byte, 0, totalValueBytes)
for i, r := range rows {
start := len(allValues)
for _, c := range r {
if !c.IsNull() {
allValues = append(allValues, c.Raw()...)
}
}
backing[i].Values = allValues[start:]
}
Comment on lines +74 to 110
return result
}
Expand All @@ -87,9 +120,26 @@ func proto3ToRows(fields []*querypb.Field, rows []*querypb.Row) [][]Value {
return [][]Value{}
}

result := make([][]Value, len(rows))
nCols := len(fields)
nRows := len(rows)

// Combined allocation: result slice headers + backing Values in one block.
// Layout: [nRows []Value headers] allocated as make([][]Value, nRows)
// then [nRows*nCols Value] backing allocated separately.
// For single row, the make calls may be optimized by the compiler.
backing := make([]Value, nRows*nCols)
result := make([][]Value, nRows)
for i, r := range rows {
result[i] = MakeRowTrusted(fields, r)
sqlRow := backing[i*nCols : (i+1)*nCols]
var offset int64
for j, length := range r.Lengths {
if length < 0 {
continue
}
sqlRow[j] = MakeTrusted(fields[j].Type, r.Values[offset:offset+length])
offset += length
}
result[i] = sqlRow
}
return result
}
Expand Down
18 changes: 16 additions & 2 deletions go/streamlog/streamlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,35 @@ func shouldEmitLogOnCondition(aCond bool, aReason string, allMatches bool, reaso
}
}

// HasSubscribers returns true if there are any active subscribers.
// This can be used to skip expensive work (e.g., copying bind variables)
// when no one is listening.
func (logger *StreamLogger[T]) HasSubscribers() bool {
logger.mu.Lock()
has := len(logger.subscribed) > 0
logger.mu.Unlock()
return has
}

// Send sends message to all the writers subscribed to logger. Calling
// Send does not block.
func (logger *StreamLogger[T]) Send(message T) {
// Send does not block. It returns the number of subscribers the message
// was delivered to.
func (logger *StreamLogger[T]) Send(message T) int {
logger.mu.Lock()
defer logger.mu.Unlock()

delivered := 0
for ch, name := range logger.subscribed {
select {
case ch <- message:
deliveredCount.Add([]string{logger.name, name}, 1)
delivered++
default:
deliveryDropCount.Add([]string{logger.name, name}, 1)
}
}
sendCount.Add(logger.name, 1)
return delivered
}

// Subscribe returns a channel which can be used to listen
Expand Down
7 changes: 7 additions & 0 deletions go/trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ func AnnotateSQL(span Span, strippedSQL fmt.Stringer) {
span.Annotate("sql-statement-type", strippedSQL.String())
}

// IsNoop returns true when the current tracer is the noop tracer.
// Callers can use this to avoid computing trace annotations that would be discarded.
func IsNoop() bool {
_, ok := currentTracer.(noopTracingServer)
return ok
}

// FromContext returns the Span from a Context if present. The bool return
// value indicates whether a Span was present in the Context.
func FromContext(ctx context.Context) (Span, bool) {
Expand Down
13 changes: 7 additions & 6 deletions go/vt/callerid/callerid.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ func GetSubcomponent(ef *vtrpcpb.CallerID) string {
}

// NewContext adds the provided EffectiveCallerID(vtrpcpb.CallerID) and ImmediateCallerID(querypb.VTGateCallerID)
// into the Context
// into the Context. Skips context wrapping for nil values to avoid unnecessary allocations.
func NewContext(ctx context.Context, ef *vtrpcpb.CallerID, im *querypb.VTGateCallerID) context.Context {
ctx = context.WithValue(
context.WithValue(ctx, effectiveCallerIDKey, ef),
immediateCallerIDKey,
im,
)
if ef != nil {
ctx = context.WithValue(ctx, effectiveCallerIDKey, ef)
}
if im != nil {
ctx = context.WithValue(ctx, immediateCallerIDKey, im)
}
return ctx
}

Expand Down
50 changes: 33 additions & 17 deletions go/vt/callinfo/plugin_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package callinfo
import (
"context"
"fmt"
"net"

"github.com/google/safehtml"
"github.com/google/safehtml/template"
Expand All @@ -29,50 +30,65 @@ import (
)

// GRPCCallInfo returns an augmented context with a CallInfo structure,
// only for gRPC contexts.
// only for gRPC contexts. Uses a combined callInfoContext to avoid
// separate allocations for the struct and context.WithValue wrapper.
func GRPCCallInfo(ctx context.Context) context.Context {
method, ok := grpc.Method(ctx)
if !ok {
return ctx
}

callinfo := &gRPCCallInfoImpl{
method: method,
c := &callInfoContext{
Context: ctx,
method: method,
}
peer, ok := peer.FromContext(ctx)
if ok {
callinfo.remoteAddr = peer.Addr.String()
if p, ok := peer.FromContext(ctx); ok {
c.remoteAddr = p.Addr
}

return NewContext(ctx, callinfo)
return c
}

type gRPCCallInfoImpl struct {
// callInfoContext is a combined context+callinfo that avoids separate allocations
// for the gRPCCallInfoImpl struct and the context.WithValue wrapper. It embeds
// the parent context and implements Value() to return itself for the callInfoKey.
type callInfoContext struct {
context.Context
method string
remoteAddr string
remoteAddr net.Addr
}

func (gci *gRPCCallInfoImpl) RemoteAddr() string {
return gci.remoteAddr
func (c *callInfoContext) Value(key any) any {
if key == callInfoKey {
return CallInfo(c)
}
return c.Context.Value(key)
}

func (c *callInfoContext) RemoteAddr() string {
if c.remoteAddr == nil {
return ""
}
return c.remoteAddr.String()
}

func (gci *gRPCCallInfoImpl) Username() string {
func (c *callInfoContext) Username() string {
return "gRPC"
}

func (gci *gRPCCallInfoImpl) Text() string {
return fmt.Sprintf("%s:%s(gRPC)", gci.remoteAddr, gci.method)
func (c *callInfoContext) Text() string {
return fmt.Sprintf("%s:%s(gRPC)", c.RemoteAddr(), c.method)
}

var grpcTmpl = template.Must(template.New("tcs").Parse("<b>Method:</b> {{.Method}} <b>Remote Addr:</b> {{.RemoteAddr}}"))

func (gci *gRPCCallInfoImpl) HTML() safehtml.HTML {
func (c *callInfoContext) HTML() safehtml.HTML {
html, err := grpcTmpl.ExecuteToHTML(struct {
Method string
RemoteAddr string
}{
Method: gci.method,
RemoteAddr: gci.remoteAddr,
Method: c.method,
RemoteAddr: c.RemoteAddr(),
})
if err != nil {
panic(err)
Expand Down
31 changes: 26 additions & 5 deletions go/vt/callinfo/plugin_grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,41 @@ limitations under the License.
package callinfo

import (
"net"
"testing"

"github.com/stretchr/testify/require"
)

func TestGRPCCallInfo(t *testing.T) {
grpcCi := gRPCCallInfoImpl{
grpcCi := callInfoContext{
Context: t.Context(),
method: "tcp",
remoteAddr: "localhost",
remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
}

require.Equal(t, t.Context(), GRPCCallInfo(t.Context()))
require.Equal(t, grpcCi.remoteAddr, grpcCi.RemoteAddr())
require.Equal(t, "127.0.0.1:8080", grpcCi.RemoteAddr())
require.Equal(t, "gRPC", grpcCi.Username())
require.Equal(t, "localhost:tcp(gRPC)", grpcCi.Text())
require.Equal(t, "<b>Method:</b> tcp <b>Remote Addr:</b> localhost", grpcCi.HTML().String())
require.Equal(t, "127.0.0.1:8080:tcp(gRPC)", grpcCi.Text())
require.Equal(t, "<b>Method:</b> tcp <b>Remote Addr:</b> 127.0.0.1:8080", grpcCi.HTML().String())
}

func TestGRPCCallInfoNilAddr(t *testing.T) {
grpcCi := callInfoContext{
Context: t.Context(),
method: "test",
}
require.Equal(t, "", grpcCi.RemoteAddr())
}

func TestGRPCCallInfoFromContext(t *testing.T) {
ctx := &callInfoContext{
Context: t.Context(),
method: "/queryservice.Query/Execute",
remoteAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
}
ci, ok := FromContext(ctx)
require.True(t, ok)
require.Equal(t, "10.0.0.1:1234", ci.RemoteAddr())
}
18 changes: 13 additions & 5 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type (
inDerived int
inSelect int

bindVarNeeds *BindVarNeeds
bindVarNeeds BindVarNeeds
shouldRewriteDatabaseFunc bool
hasStarInSelect bool

Expand Down Expand Up @@ -126,7 +126,7 @@ func Normalize(

return &RewriteASTResult{
AST: out.(Statement),
BindVarNeeds: nz.bindVarNeeds,
BindVarNeeds: &nz.bindVarNeeds,
UpdateQueryFromAST: nz.useASTQuery,
}, nil
}
Expand All @@ -146,15 +146,12 @@ func newNormalizer(
bindVars: bindVars,
reserved: reserved,
vals: make(map[Literal]string),
tupleVals: make(map[string]string),
bindVarNeeds: &BindVarNeeds{},
keyspace: keyspace,
selectLimit: selectLimit,
setVarComment: setVarComment,
fkChecksState: fkChecksState,
sysVars: sysVars,
views: views,
onLeave: make(map[*AliasedExpr]func(*AliasedExpr)),
parameterize: parameterize,
}
}
Expand Down Expand Up @@ -220,9 +217,17 @@ func (nz *normalizer) noteAliasedExprName(node *AliasedExpr) {
if node.As.NotEmpty() {
return
}
// Column references are never rewritten by normalization (only literals are),
// so there's no need to track them for alias preservation.
if _, ok := node.Expr.(*ColName); ok {
return
}
buf := NewTrackedBuffer(nil)
node.Expr.Format(buf)
rewrites := nz.bindVarNeeds.NumberOfRewrites()
if nz.onLeave == nil {
nz.onLeave = make(map[*AliasedExpr]func(*AliasedExpr))
}
nz.onLeave[node] = func(newAliasedExpr *AliasedExpr) {
if nz.bindVarNeeds.NumberOfRewrites() > rewrites {
newAliasedExpr.As = NewIdentifierCI(buf.String())
Expand Down Expand Up @@ -485,6 +490,9 @@ func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) {
}

nz.bindVars[bvname] = bvals
if nz.tupleVals == nil {
nz.tupleVals = make(map[string]string)
}
nz.tupleVals[string(key)] = bvname
}

Expand Down
21 changes: 19 additions & 2 deletions go/vt/sqlparser/tracked_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,19 @@ import (
"reflect"
"strconv"
"strings"
"sync"

"vitess.io/vitess/go/slice"
)

var trackedBufferPool = sync.Pool{
New: func() any {
return &TrackedBuffer{
Builder: new(strings.Builder),
}
},
}

// NodeFormatter defines the signature of a custom node formatter
// function that can be given to TrackedBuffer for code generation.
type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
Expand Down Expand Up @@ -398,9 +407,17 @@ func String(node SQLNode) string {
return "<nil>"
}

buf := NewTrackedBuffer(nil)
buf := trackedBufferPool.Get().(*TrackedBuffer)
buf.Builder.Reset()
buf.bindLocations = buf.bindLocations[:0]
buf.nodeFormatter = nil
buf.literal = buf.WriteString
buf.fast = true
buf.escape = escapeKeywords
node.FormatFast(buf)
return buf.String()
s := buf.String()
trackedBufferPool.Put(buf)
return s
}

// UnescapedString will return a string where no identifiers have been escaped.
Expand Down
Loading