Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2123,6 +2123,27 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// unwrapServerTransportStream traverses wrapper layers to find
// *transport.ServerStream. It uses the Unwrap pattern (similar to errors.Unwrap)
// to traverse through wrapper types that embed ServerTransportStream.
// It tracks visited wrappers to prevent infinite loops from buggy
// implementations that return themselves from Unwrap.
func unwrapServerTransportStream(s ServerTransportStream) *transport.ServerStream {
seen := map[ServerTransportStream]bool{}
for s != nil && !seen[s] {
if ts, ok := s.(*transport.ServerStream); ok {
return ts
}
seen[s] = true
u, ok := s.(interface{ Unwrap() ServerTransportStream })
if !ok {
return nil
}
s = u.Unwrap()
}
return nil
}

// SetSendCompressor sets a compressor for outbound messages from the server.
// It must not be called after any event that causes headers to be sent
// (see ServerStream.SetHeader for the complete list). Provided compressor is
Expand All @@ -2147,8 +2168,8 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
if !ok || stream == nil {
stream := unwrapServerTransportStream(ServerTransportStreamFromContext(ctx))
if stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
}

Expand All @@ -2169,8 +2190,8 @@ func SetSendCompressor(ctx context.Context, name string) error {
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
if !ok || stream == nil {
stream := unwrapServerTransportStream(ServerTransportStreamFromContext(ctx))
if stream == nil {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}

Expand Down
121 changes: 121 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand All @@ -43,6 +44,126 @@ func errorDesc(err error) string {
return err.Error()
}

// mockServerTransportStream is a mock implementing ServerTransportStream for
// testing unwrapServerTransportStream.
type mockServerTransportStream struct{}

func (s *mockServerTransportStream) Method() string { return "" }
func (s *mockServerTransportStream) SetHeader(metadata.MD) error { return nil }
func (s *mockServerTransportStream) SendHeader(metadata.MD) error { return nil }
func (s *mockServerTransportStream) SetTrailer(metadata.MD) error { return nil }

// wrappingStream wraps a ServerTransportStream and implements the Unwrap pattern.
type wrappingStream struct {
ServerTransportStream
}

func (w *wrappingStream) Unwrap() ServerTransportStream {
return w.ServerTransportStream
}

// selfWrappingStream is a buggy wrapper that returns itself from Unwrap(),
// used to test cycle protection.
type selfWrappingStream struct {
ServerTransportStream
}

func (w *selfWrappingStream) Unwrap() ServerTransportStream {
return w
}

func (s) TestUnwrapServerTransportStream(t *testing.T) {
ts := &transport.ServerStream{}

tests := []struct {
name string
s ServerTransportStream
want *transport.ServerStream
}{
{
name: "direct transport.ServerStream",
s: ts,
want: ts,
},
{
name: "single wrapper",
s: &wrappingStream{ServerTransportStream: ts},
want: ts,
},
{
name: "nested wrappers",
s: &wrappingStream{
ServerTransportStream: &wrappingStream{
ServerTransportStream: ts,
},
},
want: ts,
},
{
name: "no Unwrap method",
s: &mockServerTransportStream{},
want: nil,
},
{
name: "nil input",
s: nil,
want: nil,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := unwrapServerTransportStream(test.s)
if got != test.want {
t.Errorf("unwrapServerTransportStream() = %v, want %v", got, test.want)
}
})
}
}

func (s) TestUnwrapServerTransportStreamCycleProtection(t *testing.T) {
got := unwrapServerTransportStream(&selfWrappingStream{})
if got != nil {
t.Errorf("unwrapServerTransportStream() with cycle = %v, want nil", got)
}
}

func (s) TestClientSupportedCompressorsWithWrappedContext(t *testing.T) {
ts := &transport.ServerStream{}
// Put the transport stream into context wrapped by a layer, simulating
// what the OpenTelemetry plugin does.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = NewContextWithServerTransportStream(ctx, &wrappingStream{
ServerTransportStream: ts,
})

compressors, err := ClientSupportedCompressors(ctx)
if err != nil {
t.Fatalf("ClientSupportedCompressors() returned error: %v", err)
}
// A zero-value ServerStream has an empty clientAdvertisedCompressors
// field, so ClientAdvertisedCompressors() returns [""].
if len(compressors) != 1 || compressors[0] != "" {
t.Errorf("ClientSupportedCompressors() = %v, want [\"\"]", compressors)
}
}

func (s) TestSetSendCompressorWithWrappedContext(t *testing.T) {
ts := &transport.ServerStream{}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = NewContextWithServerTransportStream(ctx, &wrappingStream{
ServerTransportStream: ts,
})

// "identity" is always valid (skips compressor registration and client
// advertised check), so it exercises the full unwrap + SetSendCompress path.
if err := SetSendCompressor(ctx, "identity"); err != nil {
t.Fatalf("SetSendCompressor() returned error: %v", err)
}
}

func (s) TestStopBeforeServe(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions stats/opentelemetry/server_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ func (s *attachLabelsTransportStream) SetHeader(md metadata.MD) error {
return s.ServerTransportStream.SetHeader(md)
}

func (s *attachLabelsTransportStream) Unwrap() grpc.ServerTransportStream {
return s.ServerTransportStream
}

func (s *attachLabelsTransportStream) SendHeader(md metadata.MD) error {
if !s.attachedLabels.Swap(true) {
s.ServerTransportStream.SetHeader(s.metadataExchangeLabels)
Expand Down
Loading