Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metadata: fix validation issues #6001

Merged
merged 8 commits into from Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 30 additions & 19 deletions internal/metadata/metadata.go
Expand Up @@ -84,25 +84,8 @@ func Set(addr resolver.Address, md metadata.MD) resolver.Address {
// - otherwise, the header value must contain one or more characters from the set [%x20-%x7E].
func Validate(md metadata.MD) error {
for k, vals := range md {
// pseudo-header will be ignored
if k[0] == ':' {
continue
}
// check key, for i that saving a conversion if not using for range
for i := 0; i < len(k); i++ {
r := k[i]
if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' {
return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", k)
}
}
if strings.HasSuffix(k, "-bin") {
continue
}
// check value
for _, val := range vals {
if hasNotPrintable(val) {
return fmt.Errorf("header key %q contains value with non-printable ASCII characters", k)
}
if err := ValidatePair(k, vals...); err != nil {
return err
}
}
return nil
Expand All @@ -118,3 +101,31 @@ func hasNotPrintable(msg string) bool {
}
return false
}

// ValidatePair validate single pair in metadata
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I know these comments are similar to ones on master here and throughout the PR, but can you please reword to have proper grammar (capitalization and periods). Also, some of these comments are not explaining why/reasoning a block of code is happening, simply stating exactly the operations taking place within the code block. Please reword to proper grammar/explaining why/delete those which you find appropriate.

func ValidatePair(key string, vals ...string) error {
if key == "" {
return fmt.Errorf("there is an empty key in the header")
}
// pseudo-header will be ignored
if key[0] == ':' {
return nil
}
// check key, for i that saving a conversion if not using for range
for i := 0; i < len(key); i++ {
r := key[i]
if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' {
return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", key)
}
}
if strings.HasSuffix(key, "-bin") {
return nil
}
// check value
for _, val := range vals {
if hasNotPrintable(val) {
return fmt.Errorf("header key %q contains value with non-printable ASCII characters", key)
}
}
return nil
}
4 changes: 4 additions & 0 deletions internal/metadata/metadata_test.go
Expand Up @@ -100,6 +100,10 @@ func TestValidate(t *testing.T) {
md: map[string][]string{"test": {string(rune(0x19))}},
want: errors.New("header key \"test\" contains value with non-printable ASCII characters"),
},
{
md: map[string][]string{"": {"valid"}},
want: errors.New("there is an empty key in the header"),
},
{
md: map[string][]string{"test-bin": {string(rune(0x19))}},
want: nil,
Expand Down
11 changes: 10 additions & 1 deletion stream.go
Expand Up @@ -168,10 +168,19 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}

func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
if md, _, ok := metadata.FromOutgoingContextRaw(ctx); ok {
if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
// validate md
if err := imetadata.Validate(md); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
// validate added
for _, kvs := range added {
for i := 0; i < len(kvs); i += 2 {
if err := imetadata.ValidatePair(kvs[i], kvs[i+1]); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
}
}
}
if channelz.IsOn() {
cc.incrCallsStarted()
Expand Down
91 changes: 62 additions & 29 deletions test/metadata_test.go
Expand Up @@ -36,29 +36,55 @@ import (
)

func (s) TestInvalidMetadata(t *testing.T) {
grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 2)
grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 5)

tests := []struct {
md metadata.MD
want error
recv error
name string
md metadata.MD
appendMD []string
want error
recv error
}{
{
name: "invalid key",
md: map[string][]string{string(rune(0x19)): {"testVal"}},
want: status.Error(codes.Internal, "header key \"\\x19\" contains illegal characters not in [0-9a-z-_.]"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "invalid value",
md: map[string][]string{"test": {string(rune(0x19))}},
want: status.Error(codes.Internal, "header key \"test\" contains value with non-printable ASCII characters"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "invalid appended value",
md: map[string][]string{"test": {"test"}},
appendMD: []string{"/", "value"},
want: status.Error(codes.Internal, "header key \"/\" contains illegal characters not in [0-9a-z-_.]"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "empty appended key",
md: map[string][]string{"test": {"test"}},
appendMD: []string{"", "value"},
want: status.Error(codes.Internal, "there is an empty key in the header"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "empty key",
md: map[string][]string{"": {"test"}},
want: status.Error(codes.Internal, "there is an empty key in the header"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "-bin key with arbitrary value",
md: map[string][]string{"test-bin": {string(rune(0x19))}},
want: nil,
recv: io.EOF,
},
{
name: "valid key and value",
md: map[string][]string{"test": {"value"}},
want: nil,
recv: io.EOF,
Expand All @@ -77,13 +103,16 @@ func (s) TestInvalidMetadata(t *testing.T) {
}
test := tests[testNum]
testNum++
if err := stream.SetHeader(test.md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
// merge original md and added md.
md := metadata.Join(test.md, metadata.Pairs(test.appendMD...))

if err := stream.SetHeader(md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As below, please reword.

}
if err := stream.SendHeader(test.md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
if err := stream.SendHeader(md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want)
}
stream.SetTrailer(test.md)
stream.SetTrailer(md)
return nil
},
}
Expand All @@ -93,29 +122,33 @@ func (s) TestInvalidMetadata(t *testing.T) {
defer ss.Stop()

for _, test := range tests {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

ctx = metadata.NewOutgoingContext(ctx, test.md)
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) {
t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
t.Run("unary "+test.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ctx = metadata.NewOutgoingContext(ctx, test.md)
ctx = metadata.AppendToOutgoingContext(ctx, test.appendMD...)
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) {
t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
})
}

// call the stream server's api to drive the server-side unit testing
for _, test := range tests {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
stream, err := ss.Client.FullDuplexCall(ctx)
defer cancel()
if err != nil {
t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err)
continue
}
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err)
}
if _, err := stream.Recv(); status.Code(err) != status.Code(test.recv) || !strings.Contains(err.Error(), test.recv.Error()) {
t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv)
}
t.Run("streaming "+test.name, func(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional nit (not important for linter or correctness): I prefer spaces in between operators i.e. "streaming " + test.name). Here and elsewhere

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know you didn't write this, but this grammar is strange, and is incorrect (it's not being called with context.Background(), it's being called with context with a timeout). Please reword (see other examples in codebase). Perhaps something like "ss.Client.FullDuplexCall(ctx) want err: %v, got err : %v", nil, err). Here and elsewhere.

return
}
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here (see comment above). This isn't sending nil, but an allocated testpb.StreamingOutputCallRequest.

}
if _, err := stream.Recv(); status.Code(err) != status.Code(test.recv) || !strings.Contains(err.Error(), test.recv.Error()) {
t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*got err

}
})
}
}