Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3054 Handshake connection should not use legacy for LB #1482

Merged
merged 11 commits into from
Dec 4, 2023
Merged
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ evg-test-load-balancers:
go test $(BUILD_TAGS) ./mongo/integration -run TestChangeStreamSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestInitialDNSSeedlistDiscoverySpec/load_balanced -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancedConnectionHandshake -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite

.PHONY: evg-test-search-index
Expand Down
21 changes: 0 additions & 21 deletions mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,27 +768,6 @@ func TestClient(t *testing.T) {
"expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String())
}
})

// Test that OP_MSG is used for handshakes when loadBalanced is true.
opMsgLBOpts := mtest.NewOptions().ClientType(mtest.Proxy).MinServerVersion("5.0").Topologies(mtest.LoadBalanced)
mt.RunOpts("OP_MSG used for handshakes when loadBalanced is true", opMsgLBOpts, func(mt *mtest.T) {
err := mt.Client.Ping(context.Background(), mtest.PrimaryRp)
assert.Nil(mt, err, "Ping error: %v", err)

msgPairs := mt.GetProxiedMessages()
assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs))

// First three messages should be connection handshakes: one for the heartbeat connection, another for the
// application connection, and a final one for the RTT monitor connection.
for idx, pair := range msgPairs[:3] {
assert.Equal(mt, "hello", pair.CommandName, "expected command name 'hello' at index %d, got %s", idx,
pair.CommandName)

// Assert that appended OpCode is OP_MSG when loadBalanced is true.
assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode,
"expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String())
}
})
}

func TestClient_BSONOptions(t *testing.T) {
Expand Down
51 changes: 51 additions & 0 deletions mongo/integration/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/version"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

func TestHandshakeProse(t *testing.T) {
Expand Down Expand Up @@ -199,3 +200,53 @@ func TestHandshakeProse(t *testing.T) {
})
}
}

func TestLoadBalancedConnectionHandshake(t *testing.T) {
mt := mtest.New(t)

lbopts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
mtest.LoadBalanced)

mt.RunOpts("LB connection handshake uses OP_MSG", lbopts, func(mt *mtest.T) {
// Ping the server to ensure the handshake has completed.
err := mt.Client.Ping(context.Background(), nil)
require.NoError(mt, err, "Ping error: %v", err)

messages := mt.GetProxiedMessages()
handshakeMessage := messages[:1][0]

// Per the specifications, if loadBalanced=true, drivers MUST use the hello
// command for the initial handshake and use the OP_MSG protocol.
assert.Equal(mt, "hello", handshakeMessage.CommandName)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@blink1073 @qingyang-hu The handshake specifications say that "If a server API version is requested or loadBalanced: True, drivers MUST use the hello command for the initial handshake and use the OP_MSG protocol." So we should always assert that "hello" is used as the command name in this case.

Changing this lead to the discovery that we were also setting the command name incorrectly on the initial handshake on LB'd servers. I've updated that logic in hello.go and relevant SDAM tests per DRIVERS-1929 .

assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode)
})

opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
mtest.ReplicaSet,
mtest.Sharded,
mtest.Single,
mtest.ShardedReplicaSet)

mt.RunOpts("non-LB connection handshake uses OP_QUERY", opts, func(mt *mtest.T) {
// Ping the server to ensure the handshake has completed.
err := mt.Client.Ping(context.Background(), nil)
require.NoError(mt, err, "Ping error: %v", err)

messages := mt.GetProxiedMessages()
handshakeMessage := messages[:1][0]

want := wiremessage.OpQuery

hello := handshake.LegacyHello
if os.Getenv("REQUIRE_API_VERSION") == "true" {
hello = "hello"

// If the server API version is requested, then we should use OP_MSG
// regardless of the topology
want = wiremessage.OpMsg
}

assert.Equal(mt, hello, handshakeMessage.CommandName)
assert.Equal(mt, want, handshakeMessage.Sent.OpCode)
})
}
3 changes: 2 additions & 1 deletion testdata/load-balancers/sdam-error-handling.json
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@
},
"data": {
"failCommands": [
"isMaster"
"isMaster",
"hello"
],
"closeConnection": true,
"appName": "lbSDAMErrorTestClient"
Expand Down
2 changes: 1 addition & 1 deletion testdata/load-balancers/sdam-error-handling.yml
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ tests:
configureFailPoint: failCommand
mode: { times: 1 }
data:
failCommands: [isMaster]
failCommands: [isMaster, hello]
closeConnection: true
appName: *singleClientAppName
- name: insertOne
Expand Down
10 changes: 5 additions & 5 deletions x/mongo/driver/operation/hello.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ func (h *Hello) handshakeCommand(dst []byte, desc description.SelectedServer) ([
func (h *Hello) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
// Use "hello" if topology is LoadBalanced, API version is declared or server
// has responded with "helloOk". Otherwise, use legacy hello.
if desc.Kind == description.LoadBalanced || h.serverAPI != nil || desc.Server.HelloOK {
if h.loadBalanced || h.serverAPI != nil || desc.Server.HelloOK {
dst = bsoncore.AppendInt32Element(dst, "hello", 1)
} else {
dst = bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1)
Expand Down Expand Up @@ -575,8 +575,8 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti
// loadBalanced is False. If this is the case, then the drivers MUST use legacy
// hello for the first message of the initial handshake with the OP_QUERY
// protocol
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool {
return srvAPI == nil && deployment.Kind() != description.LoadBalanced
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, loadbalanced bool) bool {
return srvAPI == nil && !loadbalanced
}

func (h *Hello) createOperation() driver.Operation {
Expand All @@ -592,7 +592,7 @@ func (h *Hello) createOperation() driver.Operation {
ServerAPI: h.serverAPI,
}

if isLegacyHandshake(h.serverAPI, h.d) {
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

h.loadBalanced is set through the client options / URI query parameter.

op.Legacy = driver.LegacyHandshake
}

Expand All @@ -616,7 +616,7 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address,
ServerAPI: h.serverAPI,
}

if isLegacyHandshake(h.serverAPI, deployment) {
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
op.Legacy = driver.LegacyHandshake
}

Expand Down