Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: with TLS, set TCP user timeout on the underlying raw connection (#5646) #6321

Merged
merged 2 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
kp.Timeout = defaultServerKeepaliveTimeout
}
if kp.Time != infinity {
if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil {
if err = syscall.SetTCPUserTimeout(rawConn, kp.Timeout); err != nil {
return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err)
}
}
Expand Down
119 changes: 98 additions & 21 deletions internal/transport/keepalive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,23 @@ package transport

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"os"
"strings"
"testing"
"time"

"golang.org/x/net/http2"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/testdata"
)

const defaultTestTimeout = 10 * time.Second
Expand Down Expand Up @@ -581,47 +586,82 @@ func (s) TestKeepaliveServerEnforcementWithDormantKeepaliveOnClient(t *testing.T
// the keepalive timeout, as detailed in proposal A18.
func (s) TestTCPUserTimeout(t *testing.T) {
tests := []struct {
tls bool
time time.Duration
timeout time.Duration
clientWantTimeout time.Duration
serverWantTimeout time.Duration
}{
{
false,
10 * time.Second,
10 * time.Second,
10 * 1000 * time.Millisecond,
10 * 1000 * time.Millisecond,
},
{
false,
0,
0,
0,
20 * 1000 * time.Millisecond,
},
{
false,
infinity,
infinity,
0,
0,
},
{
true,
10 * time.Second,
10 * time.Second,
10 * 1000 * time.Millisecond,
10 * 1000 * time.Millisecond,
},
{
true,
0,
0,
0,
20 * 1000 * time.Millisecond,
},
{
true,
infinity,
infinity,
0,
0,
},
}
for _, tt := range tests {
sopts := &ServerConfig{
KeepaliveParams: keepalive.ServerParameters{
Time: tt.time,
Timeout: tt.timeout,
},
}

copts := ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: tt.time,
Timeout: tt.timeout,
},
}

if tt.tls {
copts.TransportCredentials = makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
sopts.Credentials = makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")

}

server, client, cancel := setUpWithOptions(
t,
0,
&ServerConfig{
KeepaliveParams: keepalive.ServerParameters{
Time: tt.time,
Timeout: tt.timeout,
},
},
sopts,
normal,
ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: tt.time,
Timeout: tt.timeout,
},
},
copts,
)
defer func() {
client.Close(fmt.Errorf("closed manually by test"))
Expand All @@ -630,6 +670,7 @@ func (s) TestTCPUserTimeout(t *testing.T) {
}()

var sc *http2Server
var srawConn net.Conn
// Wait until the server transport is setup.
for {
server.mu.Lock()
Expand All @@ -644,6 +685,7 @@ func (s) TestTCPUserTimeout(t *testing.T) {
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", k)
}
srawConn = server.conns[k]
}
server.mu.Unlock()
break
Expand All @@ -657,25 +699,60 @@ func (s) TestTCPUserTimeout(t *testing.T) {
}
client.CloseStream(stream, io.EOF)

cltOpt, err := syscall.GetTCPUserTimeout(client.conn)
if err != nil {
t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err)
// check client TCP user timeout only when non TLS
// TODO : find a way to get the underlying conn for client when TLS
if !tt.tls {
cltOpt, err := syscall.GetTCPUserTimeout(client.conn)
if err != nil {
t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err)
}
if cltOpt < 0 {
t.Skipf("skipping test on unsupported environment")
}
if gotTimeout := time.Duration(cltOpt) * time.Millisecond; gotTimeout != tt.clientWantTimeout {
t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.clientWantTimeout)
}
}
if cltOpt < 0 {
t.Skipf("skipping test on unsupported environment")
scConn := sc.conn
if tt.tls {
if _, ok := sc.conn.(*net.TCPConn); ok {
t.Fatalf("sc.conn is should have wrapped conn with TLS")
}
scConn = srawConn
}
if gotTimeout := time.Duration(cltOpt) * time.Millisecond; gotTimeout != tt.clientWantTimeout {
t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.clientWantTimeout)
// verify the type of scConn (on which TCP user timeout will be got)
if _, ok := scConn.(*net.TCPConn); !ok {
t.Fatalf("server underlying conn is of type %T, want net.TCPConn", scConn)
}

srvOpt, err := syscall.GetTCPUserTimeout(sc.conn)
srvOpt, err := syscall.GetTCPUserTimeout(scConn)
if err != nil {
t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err)
}
if gotTimeout := time.Duration(srvOpt) * time.Millisecond; gotTimeout != tt.serverWantTimeout {
t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.serverWantTimeout)
}

}
}

func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
if err != nil {
t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
}
b, err := os.ReadFile(testdata.Path(rootsPath))
if err != nil {
t.Fatalf("os.ReadFile(%q) failed: %v", rootsPath, err)
}
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(b) {
t.Fatal("failed to append certificates")
}
return credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: roots,
InsecureSkipVerify: true,
})
}

// checkForHealthyStream attempts to create a stream and return error if any.
Expand Down
7 changes: 4 additions & 3 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ type server struct {
port string
startedErr chan error // error (or nil) with server start value
mu sync.Mutex
conns map[ServerTransport]bool
conns map[ServerTransport]net.Conn
h *testStreamHandler
ready chan struct{}
channelzID *channelz.Identifier
Expand Down Expand Up @@ -329,13 +329,14 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
return
}
s.port = p
s.conns = make(map[ServerTransport]bool)
s.conns = make(map[ServerTransport]net.Conn)
s.startedErr <- nil
for {
conn, err := s.lis.Accept()
if err != nil {
return
}
rawConn := conn
transport, err := NewServerTransport(conn, serverConfig)
if err != nil {
return
Expand All @@ -346,7 +347,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
transport.Close(errors.New("s.conns is nil"))
return
}
s.conns[transport] = true
s.conns[transport] = rawConn
h := &testStreamHandler{t: transport.(*http2Server)}
s.h = h
s.mu.Unlock()
Expand Down