Skip to content

Commit

Permalink
server: with TLS, set TCP user timeout on the underlying raw connecti…
Browse files Browse the repository at this point in the history
…on (#5646) (#6321)
  • Loading branch information
tobotg committed Jun 27, 2023
1 parent 1634254 commit e859984
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 25 deletions.
2 changes: 1 addition & 1 deletion internal/transport/http2_server.go
Expand Up @@ -238,7 +238,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
kp.Timeout = defaultServerKeepaliveTimeout
}
if kp.Time != infinity {
if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil {
if err = syscall.SetTCPUserTimeout(rawConn, kp.Timeout); err != nil {
return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err)
}
}
Expand Down
119 changes: 98 additions & 21 deletions internal/transport/keepalive_test.go
Expand Up @@ -24,18 +24,23 @@ package transport

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"os"
"strings"
"testing"
"time"

"golang.org/x/net/http2"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/testdata"
)

const defaultTestTimeout = 10 * time.Second
Expand Down Expand Up @@ -581,47 +586,82 @@ func (s) TestKeepaliveServerEnforcementWithDormantKeepaliveOnClient(t *testing.T
// the keepalive timeout, as detailed in proposal A18.
func (s) TestTCPUserTimeout(t *testing.T) {
tests := []struct {
tls bool
time time.Duration
timeout time.Duration
clientWantTimeout time.Duration
serverWantTimeout time.Duration
}{
{
false,
10 * time.Second,
10 * time.Second,
10 * 1000 * time.Millisecond,
10 * 1000 * time.Millisecond,
},
{
false,
0,
0,
0,
20 * 1000 * time.Millisecond,
},
{
false,
infinity,
infinity,
0,
0,
},
{
true,
10 * time.Second,
10 * time.Second,
10 * 1000 * time.Millisecond,
10 * 1000 * time.Millisecond,
},
{
true,
0,
0,
0,
20 * 1000 * time.Millisecond,
},
{
true,
infinity,
infinity,
0,
0,
},
}
for _, tt := range tests {
sopts := &ServerConfig{
KeepaliveParams: keepalive.ServerParameters{
Time: tt.time,
Timeout: tt.timeout,
},
}

copts := ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: tt.time,
Timeout: tt.timeout,
},
}

if tt.tls {
copts.TransportCredentials = makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
sopts.Credentials = makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")

}

server, client, cancel := setUpWithOptions(
t,
0,
&ServerConfig{
KeepaliveParams: keepalive.ServerParameters{
Time: tt.time,
Timeout: tt.timeout,
},
},
sopts,
normal,
ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: tt.time,
Timeout: tt.timeout,
},
},
copts,
)
defer func() {
client.Close(fmt.Errorf("closed manually by test"))
Expand All @@ -630,6 +670,7 @@ func (s) TestTCPUserTimeout(t *testing.T) {
}()

var sc *http2Server
var srawConn net.Conn
// Wait until the server transport is setup.
for {
server.mu.Lock()
Expand All @@ -644,6 +685,7 @@ func (s) TestTCPUserTimeout(t *testing.T) {
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", k)
}
srawConn = server.conns[k]
}
server.mu.Unlock()
break
Expand All @@ -657,25 +699,60 @@ func (s) TestTCPUserTimeout(t *testing.T) {
}
client.CloseStream(stream, io.EOF)

cltOpt, err := syscall.GetTCPUserTimeout(client.conn)
if err != nil {
t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err)
// check client TCP user timeout only when non TLS
// TODO : find a way to get the underlying conn for client when TLS
if !tt.tls {
cltOpt, err := syscall.GetTCPUserTimeout(client.conn)
if err != nil {
t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err)
}
if cltOpt < 0 {
t.Skipf("skipping test on unsupported environment")
}
if gotTimeout := time.Duration(cltOpt) * time.Millisecond; gotTimeout != tt.clientWantTimeout {
t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.clientWantTimeout)
}
}
if cltOpt < 0 {
t.Skipf("skipping test on unsupported environment")
scConn := sc.conn
if tt.tls {
if _, ok := sc.conn.(*net.TCPConn); ok {
t.Fatalf("sc.conn is should have wrapped conn with TLS")
}
scConn = srawConn
}
if gotTimeout := time.Duration(cltOpt) * time.Millisecond; gotTimeout != tt.clientWantTimeout {
t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.clientWantTimeout)
// verify the type of scConn (on which TCP user timeout will be got)
if _, ok := scConn.(*net.TCPConn); !ok {
t.Fatalf("server underlying conn is of type %T, want net.TCPConn", scConn)
}

srvOpt, err := syscall.GetTCPUserTimeout(sc.conn)
srvOpt, err := syscall.GetTCPUserTimeout(scConn)
if err != nil {
t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err)
}
if gotTimeout := time.Duration(srvOpt) * time.Millisecond; gotTimeout != tt.serverWantTimeout {
t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.serverWantTimeout)
}

}
}

func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
if err != nil {
t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
}
b, err := os.ReadFile(testdata.Path(rootsPath))
if err != nil {
t.Fatalf("os.ReadFile(%q) failed: %v", rootsPath, err)
}
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(b) {
t.Fatal("failed to append certificates")
}
return credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: roots,
InsecureSkipVerify: true,
})
}

// checkForHealthyStream attempts to create a stream and return error if any.
Expand Down
7 changes: 4 additions & 3 deletions internal/transport/transport_test.go
Expand Up @@ -297,7 +297,7 @@ type server struct {
port string
startedErr chan error // error (or nil) with server start value
mu sync.Mutex
conns map[ServerTransport]bool
conns map[ServerTransport]net.Conn
h *testStreamHandler
ready chan struct{}
channelzID *channelz.Identifier
Expand Down Expand Up @@ -329,13 +329,14 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
return
}
s.port = p
s.conns = make(map[ServerTransport]bool)
s.conns = make(map[ServerTransport]net.Conn)
s.startedErr <- nil
for {
conn, err := s.lis.Accept()
if err != nil {
return
}
rawConn := conn
transport, err := NewServerTransport(conn, serverConfig)
if err != nil {
return
Expand All @@ -346,7 +347,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
transport.Close(errors.New("s.conns is nil"))
return
}
s.conns[transport] = true
s.conns[transport] = rawConn
h := &testStreamHandler{t: transport.(*http2Server)}
s.h = h
s.mu.Unlock()
Expand Down

0 comments on commit e859984

Please sign in to comment.