Skip to content

Commit

Permalink
GODRIVER-3172 Read response in the background after an op timeout.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdale committed Apr 2, 2024
1 parent d41a7cc commit 22b758d
Show file tree
Hide file tree
Showing 11 changed files with 254 additions and 32 deletions.
6 changes: 4 additions & 2 deletions internal/csot/csot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ type timeoutKey struct{}
// TODO default behavior.
func MakeTimeoutContext(ctx context.Context, to time.Duration) (context.Context, context.CancelFunc) {
// Only use the passed in Duration as a timeout on the Context if it
// is non-zero.
// is non-zero and if the Context doesn't already have a timeout.
cancelFunc := func() {}
if to != 0 {
if _, deadlineSet := ctx.Deadline(); to != 0 && !deadlineSet {
ctx, cancelFunc = context.WithTimeout(ctx, to)
}

// Add timeoutKey either way to indicate CSOT is enabled.
return context.WithValue(ctx, timeoutKey{}, true), cancelFunc
}

Expand Down
2 changes: 1 addition & 1 deletion mongo/change_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err
// If no deadline is set on the passed-in context, cs.client.timeout is set, and context is not already
// a Timeout context, honor cs.client.timeout in new Timeout context for change stream operation execution
// and potential retry.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && cs.client.timeout != nil && !csot.IsTimeoutContext(ctx) {
if cs.client.timeout != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *cs.client.timeout)
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down
13 changes: 11 additions & 2 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,15 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i
// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/.
func (coll *Collection) Find(ctx context.Context, filter interface{},
opts ...*options.FindOptions) (cur *Cursor, err error) {
return coll.find(ctx, nil, filter, opts...)
}

func (coll *Collection) find(
ctx context.Context,
timeout *time.Duration,
filter interface{},
opts ...*options.FindOptions,
) (cur *Cursor, err error) {

if ctx == nil {
ctx = context.Background()
Expand Down Expand Up @@ -1230,7 +1239,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
CommandMonitor(coll.client.monitor).ServerSelector(selector).
ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name).
Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI).
Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger)
Timeout(timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger)

cursorOpts := coll.client.createBaseCursorOptions()

Expand Down Expand Up @@ -1408,7 +1417,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{},
// by the server.
findOpts = append(findOpts, options.Find().SetLimit(-1))

cursor, err := coll.Find(ctx, filter, findOpts...)
cursor, err := coll.find(ctx, coll.client.timeout, filter, findOpts...)
return &SingleResult{
ctx: ctx,
cur: cursor,
Expand Down
4 changes: 2 additions & 2 deletions mongo/gridfs/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func (b *Bucket) DeleteContext(ctx context.Context, fileID interface{}) error {
// If no deadline is set on the passed-in context, Timeout is set on the Client, and context is
// not already a Timeout context, honor Timeout in new Timeout context for operation execution to
// be shared by both delete operations.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
if b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().Timeout())
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down Expand Up @@ -387,7 +387,7 @@ func (b *Bucket) DropContext(ctx context.Context) error {
// If no deadline is set on the passed-in context, Timeout is set on the Client, and context is
// not already a Timeout context, honor Timeout in new Timeout context for operation execution to
// be shared by both drop operations.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
if b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().Timeout())
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down
136 changes: 136 additions & 0 deletions mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package integration

import (
"context"
"errors"
"fmt"
"net"
"os"
Expand All @@ -19,6 +20,7 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/assert"
Expand Down Expand Up @@ -1006,3 +1008,137 @@ func TestClientStress(t *testing.T) {
}
})
}

func TestCSOT(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))

csotOpts := mtest.NewOptions().ClientOptions(options.Client().SetTimeout(10 * time.Second))
mt.RunOpts("includes maxTimeMS if CSOT timeout is set", csotOpts, func(mt *mtest.T) {
mt.Run("with context.Background", func(mt *mtest.T) {
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

maxTimeVal := mt.GetStartedEvent().Command.Lookup("maxTimeMS")

require.True(mt, len(maxTimeVal.Value) > 0, "expected maxTimeMS BSON value to be non-empty")
require.Equal(mt, maxTimeVal.Type, bsontype.Int64, "expected maxTimeMS value to be type Int64")

maxTimeMS := maxTimeVal.Int64()
assert.True(mt, maxTimeMS > 0, "expected maxTimeMS value to be greater than 0")
})
mt.Run("with context.WithTimeout", func(mt *mtest.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()

_, err := mt.Coll.InsertOne(ctx, bson.D{})
require.NoError(mt, err, "InsertOne error")

maxTimeVal := mt.GetStartedEvent().Command.Lookup("maxTimeMS")
require.True(mt, len(maxTimeVal.Value) > 0, "expected maxTimeMS BSON value to be non-empty")
require.Equal(mt, maxTimeVal.Type, bsontype.Int64, "expected maxTimeMS value to be type Int64")

maxTimeMS := maxTimeVal.Int64()
assert.True(mt,
maxTimeMS > 60_000,
"expected maxTimeMS value to be greater than 60000, but got %v",
maxTimeMS)
})
})

mt.RunOpts("timeout errors wrap context.DeadlineExceeded", csotOpts, func(mt *mtest.T) {
// Test that a client-side timeout is a context.DeadlineExceeded
mt.Run("MaxTimeMSExceeded", func(mt *mtest.T) {
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"find"},
ErrorCode: 50, // MaxTimeMSExceeded
},
})

err = mt.Coll.FindOne(context.Background(), bson.D{}).Err()

assert.True(mt,
errors.Is(err, context.DeadlineExceeded),
"expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded",
err)
assert.True(mt,
mongo.IsTimeout(err),
"expected error %[1]T(%[1]q) to be a timeout error",
err)
})
// Test that a server-side timeout is a context.DeadlineExceeded
mt.Run("ErrDeadlineWouldBeExceeded", func(mt *mtest.T) {
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"find"},
BlockConnection: true,
BlockTimeMS: 1000,
},
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
err = mt.Coll.FindOne(ctx, bson.D{}).Err()

assert.True(mt,
errors.Is(err, driver.ErrDeadlineWouldBeExceeded),
"expected error %[1]T(%[1]q) to wrap driver.ErrDeadlineWouldBeExceeded",
err)
assert.True(mt,
errors.Is(err, context.DeadlineExceeded),
"expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded",
err)
assert.True(mt,
mongo.IsTimeout(err),
"expected error %[1]T(%[1]q) to be a timeout error",
err)
})
mt.Run("context.DeadlineExceeded", func(mt *mtest.T) {
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"find"},
BlockConnection: true,
BlockTimeMS: 1000,
},
})

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = mt.Coll.FindOne(ctx, bson.D{}).Err()

assert.False(mt,
errors.Is(err, driver.ErrDeadlineWouldBeExceeded),
"expected error %[1]T(%[1]q) to not wrap driver.ErrDeadlineWouldBeExceeded",
err)
assert.True(mt,
errors.Is(err, context.DeadlineExceeded),
"expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded",
err)
assert.True(mt,
mongo.IsTimeout(err),
"expected error %[1]T(%[1]q) to be a timeout error",
err)
})
})
}
10 changes: 8 additions & 2 deletions x/mongo/driver/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ func (e Error) NamespaceNotFound() bool {

// ExtractErrorFromServerResponse extracts an error from a server response bsoncore.Document
// if there is one. Also used in testing for SDAM.
func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
func ExtractErrorFromServerResponse(doc bsoncore.Document, isCSOT bool) error {
var errmsg, codeName string
var code int32
var labels []string
Expand Down Expand Up @@ -514,14 +514,20 @@ func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
errmsg = "command failed"
}

return Error{
err := Error{
Code: code,
Message: errmsg,
Name: codeName,
Labels: labels,
TopologyVersion: tv,
Raw: doc,
}

// TODO: Comment.
if isCSOT && err.Code == 50 {
err.Wrapped = context.DeadlineExceeded
}
return err
}

if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil {
Expand Down
44 changes: 25 additions & 19 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,9 @@ func (op Operation) Execute(ctx context.Context) error {
return err
}

// If no deadline is set on the passed-in context, op.Timeout is set, and context is not already
// a Timeout context, honor op.Timeout in new Timeout context for operation execution.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && op.Timeout != nil && !csot.IsTimeoutContext(ctx) {
// If op.Timeout is set, and context is not already a Timeout context, honor
// op.Timeout in new Timeout context for operation execution.
if op.Timeout != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *op.Timeout)
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down Expand Up @@ -683,8 +683,9 @@ func (op Operation) Execute(ctx context.Context) error {
first = false
}

// Calculate maxTimeMS value to potentially be appended to the wire message.
maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor().P90(), srvr.RTTMonitor().Stats())
maxTimeMS, err := op.calculateMaxTimeMS(
ctx,
srvr.RTTMonitor().Min())
if err != nil {
return err
}
Expand Down Expand Up @@ -1089,7 +1090,7 @@ func (op Operation) readWireMessage(ctx context.Context, conn Connection) (resul
}

// decode
res, err := op.decodeResult(opcode, rem)
res, err := op.decodeResult(opcode, rem, csot.IsTimeoutContext(ctx))
// Update cluster/operation time and recovery tokens before handling the error to ensure we're properly updating
// everything.
op.updateClusterTimes(res)
Expand Down Expand Up @@ -1557,27 +1558,32 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer)
// return bsoncore.AppendDocumentElement(dst, "$clusterTime", clusterTime)
}

// calculateMaxTimeMS calculates the value of the 'maxTimeMS' field to potentially append
// to the wire message based on the current context's deadline and the 90th percentile RTT
// if the ctx is a Timeout context. If the context is not a Timeout context, it uses the
// operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is
// not a Timeout context, calculateMaxTimeMS returns 0.
func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration, rttStats string) (uint64, error) {
// calculateMaxTimeMS calculates the value of the 'maxTimeMS' field to
// potentially append to the wire message based on the current context's
// deadline and the min RTT if the ctx is a Timeout context. If the context is
// not a Timeout context, it uses the operation's MaxTimeMS if set. If no
// MaxTimeMS is set on the operation, and context is not a Timeout context,
// calculateMaxTimeMS returns 0.
func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration) (uint64, error) {
if csot.IsTimeoutContext(ctx) {
if deadline, ok := ctx.Deadline(); ok {
remainingTimeout := time.Until(deadline)
maxTime := remainingTimeout - rtt90

maxTime := remainingTimeout - rttMin

// Always round up to the next millisecond value so we never truncate the calculated
// maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms).
maxTimeMS := int64((maxTime + (time.Millisecond - 1)) / time.Millisecond)

if maxTimeMS <= 0 {
return 0, fmt.Errorf(
"remaining time %v until context deadline is less than or equal to 90th percentile RTT: %w\n%v",
"maxTimeMS calculated by context deadline is negative "+
"(remaining time: %v, minimum RTT %v): %w",
remainingTimeout,
ErrDeadlineWouldBeExceeded,
rttStats)
rttMin,
ErrDeadlineWouldBeExceeded)
}

return uint64(maxTimeMS), nil
}
} else if op.MaxTime != nil {
Expand Down Expand Up @@ -1827,7 +1833,7 @@ func (Operation) decodeOpReply(wm []byte) opReply {
return reply
}

func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte) (bsoncore.Document, error) {
func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte, isCSOT bool) (bsoncore.Document, error) {
switch opcode {
case wiremessage.OpReply:
reply := op.decodeOpReply(wm)
Expand All @@ -1845,7 +1851,7 @@ func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte) (bsoncore
return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err)
}

return rdr, ExtractErrorFromServerResponse(rdr)
return rdr, ExtractErrorFromServerResponse(rdr, isCSOT)
case wiremessage.OpMsg:
_, wm, ok := wiremessage.ReadMsgFlags(wm)
if !ok {
Expand Down Expand Up @@ -1882,7 +1888,7 @@ func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte) (bsoncore
return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err)
}

return res, ExtractErrorFromServerResponse(res)
return res, ExtractErrorFromServerResponse(res, isCSOT)
default:
return nil, fmt.Errorf("cannot decode result from %s", opcode)
}
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func TestOperation(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got, err := tc.op.calculateMaxTimeMS(tc.ctx, tc.rtt90, "")
got, err := tc.op.calculateMaxTimeMS(tc.ctx, tc.rtt90)

// Assert that the calculated maxTimeMS is less than or equal to the expected value. A few
// milliseconds will have elapsed toward the context deadline, and (remainingTimeout
Expand Down

0 comments on commit 22b758d

Please sign in to comment.