Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3172 Read response in the background after an op timeout. #1589

Merged
merged 11 commits into from
Apr 12, 2024
Merged
6 changes: 4 additions & 2 deletions internal/csot/csot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ type timeoutKey struct{}
// TODO default behavior.
func MakeTimeoutContext(ctx context.Context, to time.Duration) (context.Context, context.CancelFunc) {
// Only use the passed in Duration as a timeout on the Context if it
// is non-zero.
// is non-zero and if the Context doesn't already have a timeout.
cancelFunc := func() {}
if to != 0 {
if _, deadlineSet := ctx.Deadline(); to != 0 && !deadlineSet {
ctx, cancelFunc = context.WithTimeout(ctx, to)
}

// Add timeoutKey either way to indicate CSOT is enabled.
return context.WithValue(ctx, timeoutKey{}, true), cancelFunc
}

Expand Down
2 changes: 1 addition & 1 deletion mongo/change_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err
// If no deadline is set on the passed-in context, cs.client.timeout is set, and context is not already
// a Timeout context, honor cs.client.timeout in new Timeout context for change stream operation execution
// and potential retry.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && cs.client.timeout != nil && !csot.IsTimeoutContext(ctx) {
if cs.client.timeout != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *cs.client.timeout)
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down
25 changes: 24 additions & 1 deletion mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,25 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i
// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/.
func (coll *Collection) Find(ctx context.Context, filter interface{},
opts ...*options.FindOptions) (cur *Cursor, err error) {
// Now that "maxTimeMS" is actually set when CSOT is enabled, even when
// using a context-with-deadline, more users may encounter the case where
// setting "maxTimeMS" on a find/aggregate command will limit the lifetime
// of the cursor to that deadline, which is often unexpected. The current
// proposed improvement in DRIVERS-2722 is to omit "maxTimeMS" on Find and
// Aggregate operations (not FindOne). To maintain the existing behavior,
// include "maxTimeMS" when only "timeoutMS" is set.
_, deadlineSet := ctx.Deadline()
setMaxTimeMS := !deadlineSet && coll.client.timeout != nil

return coll.find(ctx, filter, setMaxTimeMS, opts...)
}

func (coll *Collection) find(
ctx context.Context,
filter interface{},
setMaxTimeMS bool,
matthewdale marked this conversation as resolved.
Show resolved Hide resolved
opts ...*options.FindOptions,
) (cur *Cursor, err error) {

if ctx == nil {
ctx = context.Background()
Expand Down Expand Up @@ -1232,6 +1251,10 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI).
Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger)

if setMaxTimeMS {
op = op.Timeout(coll.client.timeout)
}

cursorOpts := coll.client.createBaseCursorOptions()

cursorOpts.MarshalValueEncoderFn = newEncoderFn(coll.bsonOpts, coll.registry)
Expand Down Expand Up @@ -1408,7 +1431,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{},
// by the server.
findOpts = append(findOpts, options.Find().SetLimit(-1))

cursor, err := coll.Find(ctx, filter, findOpts...)
cursor, err := coll.find(ctx, filter, true, findOpts...)
return &SingleResult{
ctx: ctx,
cur: cursor,
Expand Down
4 changes: 2 additions & 2 deletions mongo/gridfs/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func (b *Bucket) DeleteContext(ctx context.Context, fileID interface{}) error {
// If no deadline is set on the passed-in context, Timeout is set on the Client, and context is
// not already a Timeout context, honor Timeout in new Timeout context for operation execution to
// be shared by both delete operations.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
if b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().Timeout())
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down Expand Up @@ -387,7 +387,7 @@ func (b *Bucket) DropContext(ctx context.Context) error {
// If no deadline is set on the passed-in context, Timeout is set on the Client, and context is
// not already a Timeout context, honor Timeout in new Timeout context for operation execution to
// be shared by both drop operations.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
if b.db.Client().Timeout() != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().Timeout())
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down
136 changes: 136 additions & 0 deletions mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package integration

import (
"context"
"errors"
"fmt"
"net"
"os"
Expand All @@ -19,6 +20,7 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/assert"
Expand Down Expand Up @@ -1006,3 +1008,137 @@ func TestClientStress(t *testing.T) {
}
})
}

func TestCSOT(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))

csotOpts := mtest.NewOptions().ClientOptions(options.Client().SetTimeout(10 * time.Second))
mt.RunOpts("includes maxTimeMS if CSOT timeout is set", csotOpts, func(mt *mtest.T) {
mt.Run("with context.Background", func(mt *mtest.T) {
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

maxTimeVal := mt.GetStartedEvent().Command.Lookup("maxTimeMS")

require.True(mt, len(maxTimeVal.Value) > 0, "expected maxTimeMS BSON value to be non-empty")
require.Equal(mt, maxTimeVal.Type, bsontype.Int64, "expected maxTimeMS value to be type Int64")

maxTimeMS := maxTimeVal.Int64()
assert.True(mt, maxTimeMS > 0, "expected maxTimeMS value to be greater than 0")
})
mt.Run("with context.WithTimeout", func(mt *mtest.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()

_, err := mt.Coll.InsertOne(ctx, bson.D{})
require.NoError(mt, err, "InsertOne error")

maxTimeVal := mt.GetStartedEvent().Command.Lookup("maxTimeMS")
require.True(mt, len(maxTimeVal.Value) > 0, "expected maxTimeMS BSON value to be non-empty")
require.Equal(mt, maxTimeVal.Type, bsontype.Int64, "expected maxTimeMS value to be type Int64")

maxTimeMS := maxTimeVal.Int64()
assert.True(mt,
maxTimeMS > 60_000,
"expected maxTimeMS value to be greater than 60000, but got %v",
maxTimeMS)
})
})

mt.RunOpts("timeout errors wrap context.DeadlineExceeded", csotOpts, func(mt *mtest.T) {
// Test that a client-side timeout is a context.DeadlineExceeded
mt.Run("MaxTimeMSExceeded", func(mt *mtest.T) {
matthewdale marked this conversation as resolved.
Show resolved Hide resolved
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"find"},
ErrorCode: 50, // MaxTimeMSExceeded
},
})

err = mt.Coll.FindOne(context.Background(), bson.D{}).Err()

assert.True(mt,
errors.Is(err, context.DeadlineExceeded),
"expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded",
err)
assert.True(mt,
mongo.IsTimeout(err),
"expected error %[1]T(%[1]q) to be a timeout error",
err)
})
// Test that a server-side timeout is a context.DeadlineExceeded
mt.Run("ErrDeadlineWouldBeExceeded", func(mt *mtest.T) {
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"find"},
BlockConnection: true,
BlockTimeMS: 1000,
},
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
err = mt.Coll.FindOne(ctx, bson.D{}).Err()

assert.True(mt,
errors.Is(err, driver.ErrDeadlineWouldBeExceeded),
"expected error %[1]T(%[1]q) to wrap driver.ErrDeadlineWouldBeExceeded",
err)
assert.True(mt,
errors.Is(err, context.DeadlineExceeded),
"expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded",
err)
assert.True(mt,
mongo.IsTimeout(err),
"expected error %[1]T(%[1]q) to be a timeout error",
err)
})
mt.Run("context.DeadlineExceeded", func(mt *mtest.T) {
_, err := mt.Coll.InsertOne(context.Background(), bson.D{})
require.NoError(mt, err, "InsertOne error")

mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"find"},
BlockConnection: true,
BlockTimeMS: 1000,
},
})

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = mt.Coll.FindOne(ctx, bson.D{}).Err()

assert.False(mt,
errors.Is(err, driver.ErrDeadlineWouldBeExceeded),
"expected error %[1]T(%[1]q) to not wrap driver.ErrDeadlineWouldBeExceeded",
err)
assert.True(mt,
errors.Is(err, context.DeadlineExceeded),
"expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded",
err)
assert.True(mt,
mongo.IsTimeout(err),
"expected error %[1]T(%[1]q) to be a timeout error",
err)
})
})
}
20 changes: 18 additions & 2 deletions x/mongo/driver/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@ func (e Error) NamespaceNotFound() bool {

// ExtractErrorFromServerResponse extracts an error from a server response bsoncore.Document
// if there is one. Also used in testing for SDAM.
func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
//
// Set isCSOT to true if "timeoutMS" is set on the Client.
func ExtractErrorFromServerResponse(doc bsoncore.Document, isCSOT bool) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest not using a bool here and instead pass in the context and use csot.IsTimeoutContext(ctx) within decodeResult. This ensures that the source of CSOT-ness is derived from a context.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, will do.

var errmsg, codeName string
var code int32
var labels []string
Expand Down Expand Up @@ -514,14 +516,28 @@ func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
errmsg = "command failed"
}

return Error{
err := Error{
Code: code,
Message: errmsg,
Name: codeName,
Labels: labels,
TopologyVersion: tv,
Raw: doc,
}

// If CSOT is enabled and we get a MaxTimeMSExpired error, assume that
// the error was caused by setting "maxTimeMS" on the command based on
// the context deadline or on "timeoutMS". In that case, make the error
// wrap context.DeadlineExceeded so that users can always check
//
// errors.Is(err, context.DeadlineExceeded)
//
// for either client-side or server-side timeouts.
if isCSOT && err.Code == 50 {
err.Wrapped = context.DeadlineExceeded
}

return err
}

if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil {
Expand Down
27 changes: 14 additions & 13 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,9 @@ func (op Operation) Execute(ctx context.Context) error {
return err
}

// If no deadline is set on the passed-in context, op.Timeout is set, and context is not already
// a Timeout context, honor op.Timeout in new Timeout context for operation execution.
if _, deadlineSet := ctx.Deadline(); !deadlineSet && op.Timeout != nil && !csot.IsTimeoutContext(ctx) {
// If op.Timeout is set, and context is not already a Timeout context, honor
// op.Timeout in new Timeout context for operation execution.
if op.Timeout != nil && !csot.IsTimeoutContext(ctx) {
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *op.Timeout)
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
Expand Down Expand Up @@ -683,8 +683,7 @@ func (op Operation) Execute(ctx context.Context) error {
first = false
}

// Calculate maxTimeMS value to potentially be appended to the wire message.
maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor().P90(), srvr.RTTMonitor().Stats())
maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor().P90())
if err != nil {
return err
}
Expand Down Expand Up @@ -1089,7 +1088,7 @@ func (op Operation) readWireMessage(ctx context.Context, conn Connection) (resul
}

// decode
res, err := op.decodeResult(opcode, rem)
res, err := op.decodeResult(opcode, rem, csot.IsTimeoutContext(ctx))
// Update cluster/operation time and recovery tokens before handling the error to ensure we're properly updating
// everything.
op.updateClusterTimes(res)
Expand Down Expand Up @@ -1562,7 +1561,7 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer)
// if the ctx is a Timeout context. If the context is not a Timeout context, it uses the
// operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is
// not a Timeout context, calculateMaxTimeMS returns 0.
func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration, rttStats string) (uint64, error) {
func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration) (uint64, error) {
if csot.IsTimeoutContext(ctx) {
if deadline, ok := ctx.Deadline(); ok {
remainingTimeout := time.Until(deadline)
Expand All @@ -1573,11 +1572,13 @@ func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration,
maxTimeMS := int64((maxTime + (time.Millisecond - 1)) / time.Millisecond)
if maxTimeMS <= 0 {
return 0, fmt.Errorf(
"remaining time %v until context deadline is less than or equal to 90th percentile RTT: %w\n%v",
"maxTimeMS calculated by context deadline is negative "+
"(remaining time: %v, 90th percentile RTT %v): %w",
remainingTimeout,
ErrDeadlineWouldBeExceeded,
rttStats)
rtt90,
ErrDeadlineWouldBeExceeded)
}

return uint64(maxTimeMS), nil
}
} else if op.MaxTime != nil {
Expand Down Expand Up @@ -1827,7 +1828,7 @@ func (Operation) decodeOpReply(wm []byte) opReply {
return reply
}

func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte) (bsoncore.Document, error) {
func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte, isCSOT bool) (bsoncore.Document, error) {
matthewdale marked this conversation as resolved.
Show resolved Hide resolved
switch opcode {
case wiremessage.OpReply:
reply := op.decodeOpReply(wm)
Expand All @@ -1845,7 +1846,7 @@ func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte) (bsoncore
return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err)
}

return rdr, ExtractErrorFromServerResponse(rdr)
return rdr, ExtractErrorFromServerResponse(rdr, isCSOT)
case wiremessage.OpMsg:
_, wm, ok := wiremessage.ReadMsgFlags(wm)
if !ok {
Expand Down Expand Up @@ -1882,7 +1883,7 @@ func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte) (bsoncore
return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err)
}

return res, ExtractErrorFromServerResponse(res)
return res, ExtractErrorFromServerResponse(res, isCSOT)
default:
return nil, fmt.Errorf("cannot decode result from %s", opcode)
}
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func TestOperation(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got, err := tc.op.calculateMaxTimeMS(tc.ctx, tc.rtt90, "")
got, err := tc.op.calculateMaxTimeMS(tc.ctx, tc.rtt90)

// Assert that the calculated maxTimeMS is less than or equal to the expected value. A few
// milliseconds will have elapsed toward the context deadline, and (remainingTimeout
Expand Down