Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix preloader mode in benchmarks #6359

Merged
merged 5 commits into from Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 54 additions & 34 deletions benchmark/benchmain/main.go
Expand Up @@ -53,6 +53,7 @@ import (
"reflect"
"runtime"
"runtime/pprof"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -81,7 +82,8 @@ var (
traceMode = flags.StringWithAllowedValues("trace", toggleModeOff,
fmt.Sprintf("Trace mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes)
preloaderMode = flags.StringWithAllowedValues("preloader", toggleModeOff,
fmt.Sprintf("Preloader mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes)
fmt.Sprintf("Preloader mode - One of: %v, preloader works only in streaming and unconstrained modes and will be ignored in unary mode",
strings.Join(allToggleModes, ", ")), allToggleModes)
channelzOn = flags.StringWithAllowedValues("channelz", toggleModeOff,
fmt.Sprintf("Channelz mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes)
compressorMode = flags.StringWithAllowedValues("compression", compModeOff,
Expand Down Expand Up @@ -401,20 +403,11 @@ func makeFuncUnary(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) {
}

func makeFuncStream(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) {
clients, cleanup := makeClients(bf)
streams, req, cleanup := setupStream(bf, false)

streams := make([][]testgrpc.BenchmarkService_StreamingCallClient, bf.Connections)
for cn := 0; cn < bf.Connections; cn++ {
tc := clients[cn]
streams[cn] = make([]testgrpc.BenchmarkService_StreamingCallClient, bf.MaxConcurrentCalls)
for pos := 0; pos < bf.MaxConcurrentCalls; pos++ {

stream, err := tc.StreamingCall(context.Background())
if err != nil {
logger.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
streams[cn][pos] = stream
}
var preparedMsg [][]*grpc.PreparedMsg
if bf.EnablePreloader {
preparedMsg = prepareMessages(streams, req)
}

return func(cn, pos int) {
Expand All @@ -426,24 +419,25 @@ func makeFuncStream(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) {
if bf.RespPayloadCurve != nil {
respSizeBytes = bf.RespPayloadCurve.ChooseRandom()
}
streamCaller(streams[cn][pos], reqSizeBytes, respSizeBytes)
var req interface{}
if bf.EnablePreloader {
req = preparedMsg[cn][pos]
} else {
pl := bm.NewPayload(testpb.PayloadType_COMPRESSABLE, reqSizeBytes)
req = &testpb.SimpleRequest{
ResponseType: pl.Type,
ResponseSize: int32(respSizeBytes),
Payload: pl,
}
}
streamCaller(streams[cn][pos], req)
}, cleanup
}

func makeFuncUnconstrainedStreamPreloaded(bf stats.Features) (rpcSendFunc, rpcRecvFunc, rpcCleanupFunc) {
streams, req, cleanup := setupUnconstrainedStream(bf)
streams, req, cleanup := setupStream(bf, true)

preparedMsg := make([][]*grpc.PreparedMsg, len(streams))
for cn, connStreams := range streams {
preparedMsg[cn] = make([]*grpc.PreparedMsg, len(connStreams))
for pos, stream := range connStreams {
preparedMsg[cn][pos] = &grpc.PreparedMsg{}
err := preparedMsg[cn][pos].Encode(stream, req)
if err != nil {
logger.Fatalf("%v.Encode(%v, %v) = %v", preparedMsg[cn][pos], req, stream, err)
}
}
}
preparedMsg := prepareMessages(streams, req)

return func(cn, pos int) {
streams[cn][pos].SendMsg(preparedMsg[cn][pos])
Expand All @@ -453,7 +447,7 @@ func makeFuncUnconstrainedStreamPreloaded(bf stats.Features) (rpcSendFunc, rpcRe
}

func makeFuncUnconstrainedStream(bf stats.Features) (rpcSendFunc, rpcRecvFunc, rpcCleanupFunc) {
streams, req, cleanup := setupUnconstrainedStream(bf)
streams, req, cleanup := setupStream(bf, true)

return func(cn, pos int) {
streams[cn][pos].Send(req)
Expand All @@ -462,13 +456,19 @@ func makeFuncUnconstrainedStream(bf stats.Features) (rpcSendFunc, rpcRecvFunc, r
}, cleanup
}

func setupUnconstrainedStream(bf stats.Features) ([][]testgrpc.BenchmarkService_StreamingCallClient, *testpb.SimpleRequest, rpcCleanupFunc) {
func setupStream(bf stats.Features, unconstrained bool) ([][]testgrpc.BenchmarkService_StreamingCallClient, *testpb.SimpleRequest, rpcCleanupFunc) {
clients, cleanup := makeClients(bf)

streams := make([][]testgrpc.BenchmarkService_StreamingCallClient, bf.Connections)
md := metadata.Pairs(benchmark.UnconstrainedStreamingHeader, "1",
benchmark.UnconstrainedStreamingDelayHeader, bf.SleepBetweenRPCs.String())
ctx := metadata.NewOutgoingContext(context.Background(), md)
ctx := context.Background()
if unconstrained {
md := metadata.Pairs(benchmark.UnconstrainedStreamingHeader, "1", benchmark.UnconstrainedStreamingDelayHeader, bf.SleepBetweenRPCs.String())
ctx = metadata.NewOutgoingContext(ctx, md)
}
if bf.EnablePreloader {
md := metadata.Pairs(benchmark.PreloadMsgSizeHeader, strconv.Itoa(bf.RespSizeBytes), benchmark.UnconstrainedStreamingDelayHeader, bf.SleepBetweenRPCs.String())
ctx = metadata.NewOutgoingContext(ctx, md)
}
for cn := 0; cn < bf.Connections; cn++ {
tc := clients[cn]
streams[cn] = make([]testgrpc.BenchmarkService_StreamingCallClient, bf.MaxConcurrentCalls)
Expand All @@ -491,6 +491,20 @@ func setupUnconstrainedStream(bf stats.Features) ([][]testgrpc.BenchmarkService_
return streams, req, cleanup
}

func prepareMessages(streams [][]testgrpc.BenchmarkService_StreamingCallClient, req *testpb.SimpleRequest) [][]*grpc.PreparedMsg {
preparedMsg := make([][]*grpc.PreparedMsg, len(streams))
for cn, connStreams := range streams {
preparedMsg[cn] = make([]*grpc.PreparedMsg, len(connStreams))
for pos, stream := range connStreams {
preparedMsg[cn][pos] = &grpc.PreparedMsg{}
if err := preparedMsg[cn][pos].Encode(stream, req); err != nil {
logger.Fatalf("%v.Encode(%v, %v) = %v", preparedMsg[cn][pos], req, stream, err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to return an error from here instead of calling Fatal. Please see:
https://google.github.io/styleguide/go/best-practices#program-checks-and-panics and https://google.github.io/styleguide/go/decisions#dont-panic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a common pattern used in benchmarks logger.Fatal is used 6 times in the same file. If we want to propagate the error up to the main method and handle it there we should fix all the cases, I can do it in the same PR, but this would be an unrelated change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok fair enough. If you have cycles, and would like to fix this in a follow-up PR, I would be happy to review it. Thanks.

}
}
}
return preparedMsg
}

// Makes a UnaryCall gRPC request using the given BenchmarkServiceClient and
// request and response sizes.
func unaryCaller(client testgrpc.BenchmarkServiceClient, reqSize, respSize int) {
Expand All @@ -499,8 +513,8 @@ func unaryCaller(client testgrpc.BenchmarkServiceClient, reqSize, respSize int)
}
}

func streamCaller(stream testgrpc.BenchmarkService_StreamingCallClient, reqSize, respSize int) {
if err := bm.DoStreamingRoundTrip(stream, reqSize, respSize); err != nil {
func streamCaller(stream testgrpc.BenchmarkService_StreamingCallClient, req interface{}) {
if err := bm.DoStreamingRoundTripPreloaded(stream, req); err != nil {
logger.Fatalf("DoStreamingRoundTrip failed: %v", err)
}
}
Expand Down Expand Up @@ -790,6 +804,9 @@ func processFlags() *benchOpts {
if len(opts.features.reqSizeBytes) != 0 {
log.Fatalf("you may not specify -reqPayloadCurveFiles and -reqSizeBytes at the same time")
}
if len(opts.features.enablePreloader) != 0 {
log.Fatalf("you may not specify -reqPayloadCurveFiles and -preloader at the same time")
}
for _, file := range *reqPayloadCurveFiles {
pc, err := stats.NewPayloadCurve(file)
if err != nil {
Expand All @@ -807,6 +824,9 @@ func processFlags() *benchOpts {
if len(opts.features.respSizeBytes) != 0 {
log.Fatalf("you may not specify -respPayloadCurveFiles and -respSizeBytes at the same time")
}
if len(opts.features.enablePreloader) != 0 {
log.Fatalf("you may not specify -respPayloadCurveFiles and -preloader at the same time")
}
for _, file := range *respPayloadCurveFiles {
pc, err := stats.NewPayloadCurve(file)
if err != nil {
Expand Down
58 changes: 52 additions & 6 deletions benchmark/benchmark.go
Expand Up @@ -28,6 +28,7 @@ import (
"log"
"math/rand"
"net"
"strconv"
"time"

"google.golang.org/grpc"
Expand Down Expand Up @@ -83,13 +84,35 @@ const UnconstrainedStreamingHeader = "unconstrained-streaming"
// the server should sleep between consecutive RPC responses.
const UnconstrainedStreamingDelayHeader = "unconstrained-streaming-delay"

// PreloadMsgSizeHeader indicates that the client is going to ask for
// a fixed response size and passes this size to the server.
// The server is expected to preload the response on startup.
const PreloadMsgSizeHeader = "preload-msg-size"

func (s *testServer) StreamingCall(stream testgrpc.BenchmarkService_StreamingCallServer) error {
preloadMsgSize := 0
if md, ok := metadata.FromIncomingContext(stream.Context()); ok && len(md[PreloadMsgSizeHeader]) != 0 {
val := md[PreloadMsgSizeHeader][0]
var err error
preloadMsgSize, err = strconv.Atoi(val)
if err != nil {
return fmt.Errorf("%q header value is not an integer: %s", PreloadMsgSizeHeader, err)
}
}

if md, ok := metadata.FromIncomingContext(stream.Context()); ok && len(md[UnconstrainedStreamingHeader]) != 0 {
return s.UnconstrainedStreamingCall(stream)
return s.UnconstrainedStreamingCall(stream, preloadMsgSize)
}
response := &testpb.SimpleResponse{
Payload: new(testpb.Payload),
}
preloadedResponse := &grpc.PreparedMsg{}
if preloadMsgSize > 0 {
setPayload(response.Payload, testpb.PayloadType_COMPRESSABLE, preloadMsgSize)
if err := preloadedResponse.Encode(stream, response); err != nil {
return err
}
}
in := new(testpb.SimpleRequest)
for {
// use ServerStream directly to reuse the same testpb.SimpleRequest object
Expand All @@ -101,14 +124,19 @@ func (s *testServer) StreamingCall(stream testgrpc.BenchmarkService_StreamingCal
if err != nil {
return err
}
setPayload(response.Payload, in.ResponseType, int(in.ResponseSize))
if err := stream.Send(response); err != nil {
if preloadMsgSize > 0 {
err = stream.SendMsg(preloadedResponse)
} else {
setPayload(response.Payload, in.ResponseType, int(in.ResponseSize))
err = stream.Send(response)
}
if err != nil {
return err
}
}
}

func (s *testServer) UnconstrainedStreamingCall(stream testgrpc.BenchmarkService_StreamingCallServer) error {
func (s *testServer) UnconstrainedStreamingCall(stream testgrpc.BenchmarkService_StreamingCallServer, preloadMsgSize int) error {
maxSleep := 0
if md, ok := metadata.FromIncomingContext(stream.Context()); ok && len(md[UnconstrainedStreamingDelayHeader]) != 0 {
val := md[UnconstrainedStreamingDelayHeader][0]
Expand All @@ -135,6 +163,13 @@ func (s *testServer) UnconstrainedStreamingCall(stream testgrpc.BenchmarkService
}
setPayload(response.Payload, in.ResponseType, int(in.ResponseSize))

preloadedResponse := &grpc.PreparedMsg{}
if preloadMsgSize > 0 {
if err := preloadedResponse.Encode(stream, response); err != nil {
return err
}
}

go func() {
for {
// Using RecvMsg rather than Recv to prevent reallocation of SimpleRequest.
Expand All @@ -154,7 +189,12 @@ func (s *testServer) UnconstrainedStreamingCall(stream testgrpc.BenchmarkService
if maxSleep > 0 {
time.Sleep(time.Duration(rand.Intn(maxSleep)))
}
err := stream.Send(response)
var err error
if preloadMsgSize > 0 {
err = stream.SendMsg(preloadedResponse)
} else {
err = stream.Send(response)
}
switch status.Code(err) {
case codes.Unavailable, codes.Canceled:
return
Expand Down Expand Up @@ -258,7 +298,13 @@ func DoStreamingRoundTrip(stream testgrpc.BenchmarkService_StreamingCallClient,
ResponseSize: int32(respSize),
Payload: pl,
}
if err := stream.Send(req); err != nil {
return DoStreamingRoundTripPreloaded(stream, req)
}

// DoStreamingRoundTripPreloaded performs a round trip for a single streaming rpc with preloaded payload.
func DoStreamingRoundTripPreloaded(stream testgrpc.BenchmarkService_StreamingCallClient, req interface{}) error {
// req could be either *testpb.SimpleRequest or *grpc.PreparedMsg
if err := stream.SendMsg(req); err != nil {
return fmt.Errorf("/BenchmarkService/StreamingCall.Send(_) = %v, want <nil>", err)
}
if _, err := stream.Recv(); err != nil {
Expand Down