From 2cfbc432f7a56fdbe127b34eb496ddbf1fc7c0ef Mon Sep 17 00:00:00 2001 From: mmsqe Date: Tue, 13 Jun 2023 19:38:58 +0800 Subject: [PATCH] rpc: add limit for batch request items and response size (#26681) This PR adds server-side limits for JSON-RPC batch requests. Before this change, batches were limited only by processing time. The server would pick calls from the batch and answer them until the response timeout occurred, then stop processing the remaining batch items. Here, we are adding two additional limits which can be configured: - the 'item limit': batches can have at most N items - the 'response size limit': batches can contain at most X response bytes These limits are optional in package rpc. In Geth, we set a default limit of 1000 items and 25MB response size. When a batch goes over the limit, an error response is returned to the client. However, doing this correctly isn't always possible. In JSON-RPC, only method calls with a valid `id` can be responded to. Since batches may also contain non-call messages or notifications, the best effort thing we can do to report an error with the batch itself is reporting the limit violation as an error for the first method call in the batch. If a batch is too large, but contains only notifications and responses, the error will be reported with a null `id`. The RPC client was also changed so it can deal with errors resulting from too large batches. An older client connected to the server code in this PR could get stuck until the request timeout occurred when the batch is too large. **Upgrading to a version of the RPC client containing this change is strongly recommended to avoid timeout issues.** For some weird reason, when writing the original client implementation, @fjl worked off of the assumption that responses could be distributed across batches arbitrarily. So for a batch request containing requests `[A B C]`, the server could respond with `[A B C]` but also with `[A B] [C]` or even `[A] [B] [C]` and it wouldn't make a difference to the client. So in the implementation of BatchCallContext, the client waited for all requests in the batch individually. If the server didn't respond to some of the requests in the batch, the client would eventually just time out (if a context was used). With the addition of batch limits into the server, we anticipate that people will hit this kind of error way more often. To handle this properly, the client now waits for a single response batch and expects it to contain all responses to the requests. --------- Co-authored-by: Felix Lange Co-authored-by: Martin Holst Swende --- cmd/clef/main.go | 1 + cmd/geth/main.go | 2 + cmd/utils/flags.go | 18 ++ node/api.go | 8 + node/config.go | 6 + node/defaults.go | 24 +- node/node.go | 31 +- node/rpcstack.go | 17 +- node/rpcstack_test.go | 6 +- rpc/client.go | 157 +++++++---- rpc/client_opt.go | 135 +++++++++ rpc/client_test.go | 5 +- rpc/errors.go | 22 +- rpc/handler.go | 375 ++++++++++++++++++------- rpc/http.go | 95 +++++-- rpc/inproc.go | 3 +- rpc/ipc.go | 9 +- rpc/json.go | 25 +- rpc/server.go | 32 ++- rpc/server_test.go | 39 +++ rpc/stdio.go | 9 +- rpc/subscription.go | 2 +- rpc/testdata/invalid-batch-toolarge.js | 13 + rpc/types.go | 2 +- rpc/websocket.go | 80 ++++-- 25 files changed, 869 insertions(+), 247 deletions(-) create mode 100644 rpc/client_opt.go create mode 100644 rpc/testdata/invalid-batch-toolarge.js diff --git a/cmd/clef/main.go b/cmd/clef/main.go index b1ffa38ffefaa..82abf9d7b7bf2 100644 --- a/cmd/clef/main.go +++ b/cmd/clef/main.go @@ -656,6 +656,7 @@ func signer(c *cli.Context) error { cors := utils.SplitAndTrim(c.GlobalString(utils.HTTPCORSDomainFlag.Name)) srv := rpc.NewServer() + srv.SetBatchLimits(node.DefaultConfig.BatchRequestLimit, node.DefaultConfig.BatchResponseMaxSize) err := node.RegisterApis(rpcAPI, []string{"account"}, srv, false) if err != nil { utils.Fatalf("Could not register API: %w", err) diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 76d6427fabdfa..0351809bb12b4 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -183,6 +183,8 @@ var ( utils.RPCGlobalEVMTimeoutFlag, utils.RPCGlobalTxFeeCapFlag, utils.AllowUnprotectedTxs, + utils.BatchRequestLimit, + utils.BatchResponseMaxSize, } metricsFlags = []cli.Flag{ diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 907e3ce916776..9df13f70a7ed9 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -668,6 +668,16 @@ var ( Name: "rpc.allow-unprotected-txs", Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC", } + BatchRequestLimit = &cli.IntFlag{ + Name: "rpc.batch-request-limit", + Usage: "Maximum number of requests in a batch", + Value: node.DefaultConfig.BatchRequestLimit, + } + BatchResponseMaxSize = &cli.IntFlag{ + Name: "rpc.batch-response-max-size", + Usage: "Maximum number of bytes returned from a batched call", + Value: node.DefaultConfig.BatchResponseMaxSize, + } // Network Settings MaxPeersFlag = cli.IntFlag{ @@ -1056,6 +1066,14 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { if ctx.GlobalIsSet(AllowUnprotectedTxs.Name) { cfg.AllowUnprotectedTxs = ctx.GlobalBool(AllowUnprotectedTxs.Name) } + + if ctx.IsSet(BatchRequestLimit.Name) { + cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name) + } + + if ctx.IsSet(BatchResponseMaxSize.Name) { + cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name) + } } // setGraphQL creates the GraphQL listener interface string from the set diff --git a/node/api.go b/node/api.go index 1b32399f635c7..3c87f1752b689 100644 --- a/node/api.go +++ b/node/api.go @@ -185,6 +185,10 @@ func (api *privateAdminAPI) StartHTTP(host *string, port *int, cors *string, api CorsAllowedOrigins: api.node.config.HTTPCors, Vhosts: api.node.config.HTTPVirtualHosts, Modules: api.node.config.HTTPModules, + rpcEndpointConfig: rpcEndpointConfig{ + batchItemLimit: api.node.config.BatchRequestLimit, + batchResponseSizeLimit: api.node.config.BatchResponseMaxSize, + }, } if cors != nil { config.CorsAllowedOrigins = nil @@ -259,6 +263,10 @@ func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str Modules: api.node.config.WSModules, Origins: api.node.config.WSOrigins, // ExposeAll: api.node.config.WSExposeAll, + rpcEndpointConfig: rpcEndpointConfig{ + batchItemLimit: api.node.config.BatchRequestLimit, + batchResponseSizeLimit: api.node.config.BatchResponseMaxSize, + }, } if apis != nil { config.Modules = nil diff --git a/node/config.go b/node/config.go index 2047299fb5d74..e5aebc5fd7be4 100644 --- a/node/config.go +++ b/node/config.go @@ -203,6 +203,12 @@ type Config struct { // JWTSecret is the hex-encoded jwt secret. JWTSecret string `toml:",omitempty"` + + // BatchRequestLimit is the maximum number of requests in a batch. + BatchRequestLimit int `toml:",omitempty"` + + // BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call. + BatchResponseMaxSize int `toml:",omitempty"` } // IPCEndpoint resolves an IPC endpoint based on a configured value, taking into diff --git a/node/defaults.go b/node/defaults.go index fd0277e29dc93..6c947a863187c 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -48,17 +48,19 @@ var ( // DefaultConfig contains reasonable default settings. var DefaultConfig = Config{ - DataDir: DefaultDataDir(), - HTTPPort: DefaultHTTPPort, - AuthAddr: DefaultAuthHost, - AuthPort: DefaultAuthPort, - AuthVirtualHosts: DefaultAuthVhosts, - HTTPModules: []string{"net", "web3"}, - HTTPVirtualHosts: []string{"localhost"}, - HTTPTimeouts: rpc.DefaultHTTPTimeouts, - WSPort: DefaultWSPort, - WSModules: []string{"net", "web3"}, - GraphQLVirtualHosts: []string{"localhost"}, + DataDir: DefaultDataDir(), + HTTPPort: DefaultHTTPPort, + AuthAddr: DefaultAuthHost, + AuthPort: DefaultAuthPort, + AuthVirtualHosts: DefaultAuthVhosts, + HTTPModules: []string{"net", "web3"}, + HTTPVirtualHosts: []string{"localhost"}, + HTTPTimeouts: rpc.DefaultHTTPTimeouts, + WSPort: DefaultWSPort, + WSModules: []string{"net", "web3"}, + BatchRequestLimit: 1000, + BatchResponseMaxSize: 25 * 1000 * 1000, + GraphQLVirtualHosts: []string{"localhost"}, P2P: p2p.Config{ ListenAddr: ":30303", MaxPeers: 50, diff --git a/node/node.go b/node/node.go index 0a2b9eb836920..70b3ecb9bfe25 100644 --- a/node/node.go +++ b/node/node.go @@ -101,10 +101,11 @@ func New(conf *Config) (*Node, error) { if strings.HasSuffix(conf.Name, ".ipc") { return nil, errors.New(`Config.Name cannot end in ".ipc"`) } - + server := rpc.NewServer() + server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize) node := &Node{ config: conf, - inprocHandler: rpc.NewServer(), + inprocHandler: server, eventmux: new(event.TypeMux), log: conf.Logger, stop: make(chan struct{}), @@ -395,7 +396,10 @@ func (n *Node) startRPC() error { servers []*httpServer open, all = n.GetAPIs() ) - + rpcConfig := rpcEndpointConfig{ + batchItemLimit: n.config.BatchRequestLimit, + batchResponseSizeLimit: n.config.BatchResponseMaxSize, + } initHttp := func(server *httpServer, apis []rpc.API, port int) error { if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { return err @@ -405,6 +409,7 @@ func (n *Node) startRPC() error { Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, prefix: n.config.HTTPPathPrefix, + rpcEndpointConfig: rpcConfig, }); err != nil { return err } @@ -418,9 +423,10 @@ func (n *Node) startRPC() error { return err } if err := server.enableWS(n.rpcAPIs, wsConfig{ - Modules: n.config.WSModules, - Origins: n.config.WSOrigins, - prefix: n.config.WSPathPrefix, + Modules: n.config.WSModules, + Origins: n.config.WSOrigins, + rpcEndpointConfig: rpcConfig, + prefix: n.config.WSPathPrefix, }); err != nil { return err } @@ -434,26 +440,29 @@ func (n *Node) startRPC() error { if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { return err } + sharedConfig := rpcConfig + sharedConfig.jwtSecret = secret if err := server.enableRPC(apis, httpConfig{ CorsAllowedOrigins: DefaultAuthCors, Vhosts: n.config.AuthVirtualHosts, Modules: DefaultAuthModules, prefix: DefaultAuthPrefix, - jwtSecret: secret, + rpcEndpointConfig: sharedConfig, }); err != nil { return err } servers = append(servers, server) + // Enable auth via WS server = n.wsServerForPort(port, true) if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { return err } if err := server.enableWS(apis, wsConfig{ - Modules: DefaultAuthModules, - Origins: DefaultAuthOrigins, - prefix: DefaultAuthPrefix, - jwtSecret: secret, + Modules: DefaultAuthModules, + Origins: DefaultAuthOrigins, + prefix: DefaultAuthPrefix, + rpcEndpointConfig: sharedConfig, }); err != nil { return err } diff --git a/node/rpcstack.go b/node/rpcstack.go index 0d2be9008a41e..a733ba937fbac 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -40,14 +40,21 @@ type httpConfig struct { Vhosts []string prefix string // path prefix on which to mount http handler jwtSecret []byte // optional JWT secret + rpcEndpointConfig } // wsConfig is the JSON-RPC/Websocket configuration type wsConfig struct { - Origins []string - Modules []string - prefix string // path prefix on which to mount ws handler - jwtSecret []byte // optional JWT secret + Origins []string + Modules []string + prefix string // path prefix on which to mount ws handler + rpcEndpointConfig +} + +type rpcEndpointConfig struct { + jwtSecret []byte // optional JWT secret + batchItemLimit int + batchResponseSizeLimit int } type rpcHandler struct { @@ -281,6 +288,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { // Create RPC server and handler. srv := rpc.NewServer() + srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit) if err := RegisterApis(apis, config.Modules, srv, false); err != nil { return err } @@ -312,6 +320,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } // Create RPC server and handler. srv := rpc.NewServer() + srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit) if err := RegisterApis(apis, config.Modules, srv, false); err != nil { return err } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 229a5b5e53baa..949a8c1176c47 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -314,8 +314,10 @@ func TestJWT(t *testing.T) { ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) return ss } - srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, - true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) + cfg := rpcEndpointConfig{jwtSecret: []byte("secret")} + httpcfg := &httpConfig{rpcEndpointConfig: cfg} + wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg} + srv := createAndStartServer(t, httpcfg, true, wscfg) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) diff --git a/rpc/client.go b/rpc/client.go index d3ce0297754c9..455767bb9b8f0 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "net/url" + "os" "reflect" "strconv" "sync/atomic" @@ -32,14 +33,15 @@ import ( var ( ErrClientQuit = errors.New("client is closed") - ErrNoResult = errors.New("no result in JSON-RPC response") + ErrNoResult = errors.New("JSON-RPC response has no result") + ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") errClientReconnected = errors.New("client reconnected") errDead = errors.New("connection lost") ) +// Timeouts const ( - // Timeouts defaultDialTimeout = 10 * time.Second // used if context has no deadline subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls ) @@ -82,6 +84,10 @@ type Client struct { // This function, if non-nil, is called when the connection is lost. reconnectFunc reconnectFunc + // config fields + batchItemLimit int + batchResponseMaxSize int + // writeConn is used for writing to the connection on the caller's goroutine. It should // only be accessed outside of dispatch, with the write lock held. The write lock is // taken by sending on reqInit and released by sending on reqSent. @@ -112,7 +118,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn { ctx := context.Background() ctx = context.WithValue(ctx, clientContextKey{}, c) ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) - handler := newHandler(ctx, conn, c.idgen, c.services) + handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize) return &clientConn{conn, handler} } @@ -126,14 +132,17 @@ type readOp struct { batch bool } +// requestOp represents a pending request. This is used for both batch and non-batch +// requests. type requestOp struct { - ids []json.RawMessage - err error - resp chan *jsonrpcMessage // receives up to len(ids) responses - sub *ClientSubscription // only set for EthSubscribe requests + ids []json.RawMessage + err error + resp chan []*jsonrpcMessage // the response goes here + sub *ClientSubscription // set for Subscribe requests. + hadResponse bool // true when the request was responded to } -func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) { +func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) { select { case <-ctx.Done(): // Send the timeout to dispatch so it can remove the request IDs. @@ -167,23 +176,36 @@ func Dial(rawurl string) (*Client, error) { // // The context is used to cancel or time out the initial connection establishment. It does // not affect subsequent interactions with the client. -func DialContext(ctx context.Context, rawurl string) (*Client, error) { +func DialContext(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err } + + cfg := new(clientConfig) + for _, opt := range options { + opt.applyOption(cfg) + } + + var reconnect reconnectFunc switch u.Scheme { case "http", "https": - return DialHTTP(rawurl) + reconnect = newClientTransportHTTP(rawurl, cfg) case "ws", "wss": - return DialWebsocket(ctx, rawurl, "") + rc, err := newClientTransportWS(rawurl, cfg) + if err != nil { + return nil, err + } + reconnect = rc case "stdio": - return DialStdIO(ctx) + reconnect = newClientTransportIO(os.Stdin, os.Stdout) case "": - return DialIPC(ctx, rawurl) + reconnect = newClientTransportIPC(rawurl) default: return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) } + + return newClient(ctx, cfg, reconnect) } // ClientFromContext retrieves the client from the context, if any. This can be used to perform @@ -193,33 +215,42 @@ func ClientFromContext(ctx context.Context) (*Client, bool) { return client, ok } -func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { +func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) { conn, err := connect(initctx) if err != nil { return nil, err } - c := initClient(conn, randomIDGenerator(), new(serviceRegistry)) + c := initClient(conn, new(serviceRegistry), cfg) c.reconnectFunc = connect return c, nil } -func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { +func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client { _, isHTTP := conn.(*httpConn) c := &Client{ - isHTTP: isHTTP, - idgen: idgen, - services: services, - writeConn: conn, - close: make(chan struct{}), - closing: make(chan struct{}), - didClose: make(chan struct{}), - reconnected: make(chan ServerCodec), - readOp: make(chan readOp), - readErr: make(chan error), - reqInit: make(chan *requestOp), - reqSent: make(chan error, 1), - reqTimeout: make(chan *requestOp), - } + isHTTP: isHTTP, + services: services, + idgen: cfg.idgen, + batchItemLimit: cfg.batchItemLimit, + batchResponseMaxSize: cfg.batchResponseLimit, + writeConn: conn, + close: make(chan struct{}), + closing: make(chan struct{}), + didClose: make(chan struct{}), + reconnected: make(chan ServerCodec), + readOp: make(chan readOp), + readErr: make(chan error), + reqInit: make(chan *requestOp), + reqSent: make(chan error, 1), + reqTimeout: make(chan *requestOp), + } + + // Set defaults. + if c.idgen == nil { + c.idgen = randomIDGenerator() + } + + // Launch the main loop. if !isHTTP { go c.dispatch(conn) } @@ -297,7 +328,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str if err != nil { return err } - op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} + op := &requestOp{ + ids: []json.RawMessage{msg.ID}, + resp: make(chan []*jsonrpcMessage, 1), + } if c.isHTTP { err = c.sendHTTP(ctx, op, msg) @@ -309,9 +343,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str } // dispatch has accepted the request and will close the channel when it quits. - switch resp, err := op.wait(ctx, c); { - case err != nil: + batchresp, err := op.wait(ctx, c) + if err != nil { return err + } + resp := batchresp[0] + switch { case resp.Error != nil: return resp.Error case len(resp.Result) == 0: @@ -349,7 +386,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { ) op := &requestOp{ ids: make([]json.RawMessage, len(b)), - resp: make(chan *jsonrpcMessage, len(b)), + resp: make(chan []*jsonrpcMessage, 1), } for i, elem := range b { msg, err := c.newMessage(elem.Method, elem.Args...) @@ -367,28 +404,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { } else { err = c.send(ctx, op, msgs) } + if err != nil { + return err + } + + batchresp, err := op.wait(ctx, c) + if err != nil { + return err + } // Wait for all responses to come back. - for n := 0; n < len(b) && err == nil; n++ { - var resp *jsonrpcMessage - resp, err = op.wait(ctx, c) - if err != nil { - break + for n := 0; n < len(batchresp) && err == nil; n++ { + resp := batchresp[n] + if resp == nil { + // Ignore null responses. These can happen for batches sent via HTTP. + continue } + // Find the element corresponding to this response. - // The element is guaranteed to be present because dispatch - // only sends valid IDs to our channel. - elem := &b[byID[string(resp.ID)]] - if resp.Error != nil { - elem.Error = resp.Error + index, ok := byID[string(resp.ID)] + if !ok { continue } - if len(resp.Result) == 0 { + delete(byID, string(resp.ID)) + + // Assign result and error. + elem := &b[index] + switch { + case resp.Error != nil: + elem.Error = resp.Error + case resp.Result == nil: elem.Error = ErrNoResult - continue + default: + elem.Error = json.Unmarshal(resp.Result, elem.Result) } - elem.Error = json.Unmarshal(resp.Result, elem.Result) } + + // Check that all expected responses have been received. + for _, index := range byID { + elem := &b[index] + elem.Error = ErrMissingBatchResponse + } + return err } @@ -449,7 +506,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf } op := &requestOp{ ids: []json.RawMessage{msg.ID}, - resp: make(chan *jsonrpcMessage), + resp: make(chan []*jsonrpcMessage, 1), sub: newClientSubscription(c, namespace, chanVal), } @@ -499,7 +556,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { return err } } - err := c.writeConn.writeJSON(ctx, msg) + err := c.writeConn.writeJSON(ctx, msg, false) if err != nil { c.writeConn = nil if !retry { @@ -632,7 +689,7 @@ func (c *Client) read(codec ServerCodec) { for { msgs, batch, err := codec.readBatch() if _, ok := err.(*json.SyntaxError); ok { - codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()})) + codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()}), true) } if err != nil { c.readErr <- err diff --git a/rpc/client_opt.go b/rpc/client_opt.go new file mode 100644 index 0000000000000..5bef08cca8410 --- /dev/null +++ b/rpc/client_opt.go @@ -0,0 +1,135 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import ( + "net/http" + + "github.com/gorilla/websocket" +) + +// ClientOption is a configuration option for the RPC client. +type ClientOption interface { + applyOption(*clientConfig) +} + +type clientConfig struct { + // HTTP settings + httpClient *http.Client + httpHeaders http.Header + httpAuth HTTPAuth + + // WebSocket options + wsDialer *websocket.Dialer + + // RPC handler options + idgen func() ID + batchItemLimit int + batchResponseLimit int +} + +func (cfg *clientConfig) initHeaders() { + if cfg.httpHeaders == nil { + cfg.httpHeaders = make(http.Header) + } +} + +func (cfg *clientConfig) setHeader(key, value string) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) +} + +type optionFunc func(*clientConfig) + +func (fn optionFunc) applyOption(opt *clientConfig) { + fn(opt) +} + +// WithWebsocketDialer configures the websocket.Dialer used by the RPC client. +func WithWebsocketDialer(dialer websocket.Dialer) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsDialer = &dialer + }) +} + +// WithHeader configures HTTP headers set by the RPC client. Headers set using this option +// will be used for both HTTP and WebSocket connections. +func WithHeader(key, value string) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) + }) +} + +// WithHeaders configures HTTP headers set by the RPC client. Headers set using this +// option will be used for both HTTP and WebSocket connections. +func WithHeaders(headers http.Header) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + for k, vs := range headers { + cfg.httpHeaders[k] = vs + } + }) +} + +// WithHTTPClient configures the http.Client used by the RPC client. +func WithHTTPClient(c *http.Client) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.httpClient = c + }) +} + +// WithHTTPAuth configures HTTP request authentication. The given provider will be called +// whenever a request is made. Note that only one authentication provider can be active at +// any time. +func WithHTTPAuth(a HTTPAuth) ClientOption { + if a == nil { + panic("nil auth") + } + return optionFunc(func(cfg *clientConfig) { + cfg.httpAuth = a + }) +} + +// A HTTPAuth function is called by the client whenever a HTTP request is sent. +// The function must be safe for concurrent use. +// +// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add +// auth information to the request. +type HTTPAuth func(h http.Header) error + +// WithBatchItemLimit changes the maximum number of items allowed in batch requests. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchItemLimit(limit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchItemLimit = limit + }) +} + +// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be +// generated for batch requests. When this limit is reached, further calls in the batch +// will not be processed. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchResponseSizeLimit(sizeLimit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchResponseLimit = sizeLimit + }) +} diff --git a/rpc/client_test.go b/rpc/client_test.go index 04c847d0d626c..52fe37abb2b22 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -238,7 +238,7 @@ func testClientCancel(transport string, t *testing.T) { _, hasDeadline := ctx.Deadline() t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) // default: - // t.Logf("got expected error with %v wait time: %v", timeout, err) + // t.Logf("got expected error with %v wait time: %v", timeout, err) } cancel() } @@ -415,7 +415,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) { defer srv.Stop() // Create the client on the other end of the pipe. - client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) { return NewCodec(p2), nil }) defer client.Close() diff --git a/rpc/errors.go b/rpc/errors.go index 4c06a745fbd8b..8eb0a3751b81c 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -56,7 +56,17 @@ var ( _ Error = new(invalidParamsError) ) -const defaultErrorCode = -32000 +const ( + defaultErrorCode = -32000 + errcodeTimeout = -32002 + errcodeResponseTooLarge = -32003 +) + +const ( + errMsgTimeout = "request timed out" + errMsgResponseTooLarge = "response too large" + errMsgBatchTooLarge = "batch too large" +) type methodNotFoundError struct{ method string } @@ -101,3 +111,13 @@ type invalidParamsError struct{ message string } func (e *invalidParamsError) ErrorCode() int { return -32602 } func (e *invalidParamsError) Error() string { return e.message } + +// internalServerError is used for server errors during request processing. +type internalServerError struct { + code int + message string +} + +func (e *internalServerError) ErrorCode() int { return e.code } + +func (e *internalServerError) Error() string { return e.message } diff --git a/rpc/handler.go b/rpc/handler.go index cd95a067f3e26..32516d144c4a3 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -34,33 +34,34 @@ import ( // // The entry points for incoming messages are: // -// h.handleMsg(message) -// h.handleBatch(message) +// h.handleMsg(message) +// h.handleBatch(message) // // Outgoing calls use the requestOp struct. Register the request before sending it // on the connection: // -// op := &requestOp{ids: ...} -// h.addRequestOp(op) +// op := &requestOp{ids: ...} +// h.addRequestOp(op) // // Now send the request, then wait for the reply to be delivered through handleMsg: // -// if err := op.wait(...); err != nil { -// h.removeRequestOp(op) // timeout, etc. -// } -// +// if err := op.wait(...); err != nil { +// h.removeRequestOp(op) // timeout, etc. +// } type handler struct { - reg *serviceRegistry - unsubscribeCb *callback - idgen func() ID // subscription ID generator - respWait map[string]*requestOp // active client requests - clientSubs map[string]*ClientSubscription // active client subscriptions - callWG sync.WaitGroup // pending call goroutines - rootCtx context.Context // canceled by close() - cancelRoot func() // cancel function for rootCtx - conn jsonWriter // where responses will be sent - log log.Logger - allowSubscribe bool + reg *serviceRegistry + unsubscribeCb *callback + idgen func() ID // subscription ID generator + respWait map[string]*requestOp // active client requests + clientSubs map[string]*ClientSubscription // active client subscriptions + callWG sync.WaitGroup // pending call goroutines + rootCtx context.Context // canceled by close() + cancelRoot func() // cancel function for rootCtx + conn jsonWriter // where responses will be sent + log log.Logger + allowSubscribe bool + batchRequestLimit int + batchResponseMaxSize int subLock sync.Mutex serverSubs map[ID]*Subscription @@ -71,19 +72,21 @@ type callProc struct { notifiers []*Notifier } -func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler { +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler { rootCtx, cancelRoot := context.WithCancel(connCtx) h := &handler{ - reg: reg, - idgen: idgen, - conn: conn, - respWait: make(map[string]*requestOp), - clientSubs: make(map[string]*ClientSubscription), - rootCtx: rootCtx, - cancelRoot: cancelRoot, - allowSubscribe: true, - serverSubs: make(map[ID]*Subscription), - log: log.Root(), + reg: reg, + idgen: idgen, + conn: conn, + respWait: make(map[string]*requestOp), + clientSubs: make(map[string]*ClientSubscription), + rootCtx: rootCtx, + cancelRoot: cancelRoot, + allowSubscribe: true, + serverSubs: make(map[ID]*Subscription), + log: log.Root(), + batchRequestLimit: batchRequestLimit, + batchResponseMaxSize: batchResponseMaxSize, } if conn.remoteAddr() != "" { h.log = h.log.New("conn", conn.remoteAddr()) @@ -92,61 +95,218 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * return h } +// batchCallBuffer manages in progress call messages and their responses during a batch +// call. Calls need to be synchronized between the processing and timeout-triggering +// goroutines. +type batchCallBuffer struct { + mutex sync.Mutex + calls []*jsonrpcMessage + resp []*jsonrpcMessage + wrote bool +} + +// nextCall returns the next unprocessed message. +func (b *batchCallBuffer) nextCall() *jsonrpcMessage { + b.mutex.Lock() + defer b.mutex.Unlock() + + if len(b.calls) == 0 { + return nil + } + // The popping happens in `pushAnswer`. The in progress call is kept + // so we can return an error for it in case of timeout. + msg := b.calls[0] + return msg +} + +// pushResponse adds the response to last call returned by nextCall. +func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) { + b.mutex.Lock() + defer b.mutex.Unlock() + + if answer != nil { + b.resp = append(b.resp, answer) + } + b.calls = b.calls[1:] +} + +// write sends the responses. +func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.doWrite(ctx, conn, false) +} + +// respondWithError sends the responses added so far. For the remaining unanswered call +// messages, it responds with the given error. +func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) { + b.mutex.Lock() + defer b.mutex.Unlock() + + for _, msg := range b.calls { + if !msg.isNotification() { + b.resp = append(b.resp, msg.errorResponse(err)) + } + } + b.doWrite(ctx, conn, true) +} + +// doWrite actually writes the response. +// This assumes b.mutex is held. +func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) { + if b.wrote { + return + } + b.wrote = true // can only write once + if len(b.resp) > 0 { + conn.writeJSON(ctx, b.resp, isErrorResponse) + } +} + // handleBatch executes all messages in a batch and returns the responses. func (h *handler) handleBatch(msgs []*jsonrpcMessage) { // Emit error response for empty batches: if len(msgs) == 0 { h.startCallProc(func(cp *callProc) { - h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) + h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}), true) + }) + return + } + // Apply limit on total number of requests. + if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit { + h.startCallProc(func(cp *callProc) { + h.respondWithBatchTooLarge(cp, msgs) }) return } - // Handle non-call messages first: + // Handle non-call messages first. + // Here we need to find the requestOp that sent the request batch. calls := make([]*jsonrpcMessage, 0, len(msgs)) - for _, msg := range msgs { - if handled := h.handleImmediate(msg); !handled { - calls = append(calls, msg) - } - } + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + calls = append(calls, msg) + }) if len(calls) == 0 { return } + // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { - answers := make([]*jsonrpcMessage, 0, len(msgs)) - for _, msg := range calls { - if answer := h.handleCallMsg(cp, msg); answer != nil { - answers = append(answers, answer) + var ( + timer *time.Timer + cancel context.CancelFunc + callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))} + ) + + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // currently-running method might not return immediately on timeout, we must wait + // for the timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + err := &internalServerError{errcodeTimeout, errMsgTimeout} + callBuffer.respondWithError(cp.ctx, h.conn, err) + }) + } + + responseBytes := 0 + for { + // No need to handle rest of calls if timed out. + if cp.ctx.Err() != nil { + break + } + msg := callBuffer.nextCall() + if msg == nil { + break + } + resp := h.handleCallMsg(cp, msg) + callBuffer.pushResponse(resp) + if resp != nil && h.batchResponseMaxSize != 0 { + responseBytes += len(resp.Result) + if responseBytes > h.batchResponseMaxSize { + err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge} + callBuffer.respondWithError(cp.ctx, h.conn, err) + break + } } } - h.addSubscriptions(cp.notifiers) - if len(answers) > 0 { - h.conn.writeJSON(cp.ctx, answers) + if timer != nil { + timer.Stop() } + + h.addSubscriptions(cp.notifiers) + callBuffer.write(cp.ctx, h.conn) for _, n := range cp.notifiers { n.activate() } }) } -// handleMsg handles a single message. -func (h *handler) handleMsg(msg *jsonrpcMessage) { - if ok := h.handleImmediate(msg); ok { - return - } - h.startCallProc(func(cp *callProc) { - answer := h.handleCallMsg(cp, msg) - h.addSubscriptions(cp.notifiers) - if answer != nil { - h.conn.writeJSON(cp.ctx, answer) - } - for _, n := range cp.notifiers { - n.activate() +func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) { + resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge}) + // Find the first call and add its "id" field to the error. + // This is the best we can do, given that the protocol doesn't have a way + // of reporting an error for the entire batch. + for _, msg := range batch { + if msg.isCall() { + resp.ID = msg.ID + break } + } + h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true) +} + +// handleMsg handles a single non-batch message. +func (h *handler) handleMsg(msg *jsonrpcMessage) { + msgs := []*jsonrpcMessage{msg} + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + h.startCallProc(func(cp *callProc) { + h.handleNonBatchCall(cp, msg) + }) }) } +func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } + h.addSubscriptions(cp.notifiers) + if answer != nil { + responded.Do(func() { + h.conn.writeJSON(cp.ctx, answer, false) + }) + } + for _, n := range cp.notifiers { + n.activate() + } +} + // close cancels all requests except for inflightReq and waits for // call goroutines to shut down. func (h *handler) close(err error, inflightReq *requestOp) { @@ -227,23 +387,60 @@ func (h *handler) startCallProc(fn func(*callProc)) { }() } -// handleImmediate executes non-call messages. It returns false if the message is a -// call or requires a reply. -func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { - start := time.Now() - switch { - case msg.isNotification(): - if strings.HasSuffix(msg.Method, notificationMethodSuffix) { - h.handleSubscriptionResult(msg) - return true +// handleResponse processes method call responses. +func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) { + var resolvedops []*requestOp + handleResp := func(msg *jsonrpcMessage) { + op := h.respWait[string(msg.ID)] + if op == nil { + h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) + return } - return false - case msg.isResponse(): - h.handleResponse(msg) - h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) - return true - default: - return false + resolvedops = append(resolvedops, op) + delete(h.respWait, string(msg.ID)) + + // For subscription responses, start the subscription if the server + // indicates success. EthSubscribe gets unblocked in either case through + // the op.resp channel. + if op.sub != nil { + if msg.Error != nil { + op.err = msg.Error + } else { + op.err = json.Unmarshal(msg.Result, &op.sub.subid) + if op.err == nil { + go op.sub.run() + h.clientSubs[op.sub.subid] = op.sub + } + } + } + + if !op.hadResponse { + op.hadResponse = true + op.resp <- batch + } + } + + for _, msg := range batch { + start := time.Now() + switch { + case msg.isResponse(): + handleResp(msg) + h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + + case msg.isNotification(): + if strings.HasSuffix(msg.Method, notificationMethodSuffix) { + h.handleSubscriptionResult(msg) + continue + } + handleCall(msg) + + default: + handleCall(msg) + } + } + + for _, op := range resolvedops { + h.removeRequestOp(op) } } @@ -259,33 +456,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { } } -// handleResponse processes method call responses. -func (h *handler) handleResponse(msg *jsonrpcMessage) { - op := h.respWait[string(msg.ID)] - if op == nil { - h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) - return - } - delete(h.respWait, string(msg.ID)) - // For normal responses, just forward the reply to Call/BatchCall. - if op.sub == nil { - op.resp <- msg - return - } - // For subscription responses, start the subscription if the server - // indicates success. EthSubscribe gets unblocked in either case through - // the op.resp channel. - defer close(op.resp) - if msg.Error != nil { - op.err = msg.Error - return - } - if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { - go op.sub.run() - h.clientSubs[op.sub.subid] = op.sub - } -} - // handleCallMsg executes a call message and returns the answer. func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { start := time.Now() @@ -294,6 +464,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess h.handleCall(ctx, msg) h.log.Debug("Served "+msg.Method, "duration", time.Since(start)) return nil + case msg.isCall(): resp := h.handleCall(ctx, msg) var ctx []interface{} @@ -308,8 +479,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess h.log.Debug("Served "+msg.Method, ctx...) } return resp + case msg.hasValidID(): return msg.errorResponse(&invalidRequestError{"invalid request"}) + default: return errorMessage(&invalidRequestError{"invalid request"}) } @@ -329,6 +502,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage if callb == nil { return msg.errorResponse(&methodNotFoundError{method: msg.Method}) } + args, err := parsePositionalArguments(msg.Params, callb.argTypes) if err != nil { return msg.errorResponse(&invalidParamsError{err.Error()}) @@ -348,6 +522,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage rpcServingTimer.UpdateSince(start) updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start)) } + return answer } diff --git a/rpc/http.go b/rpc/http.go index 9f44649573493..f57335656c307 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "math" "mime" "net/http" "net/url" @@ -51,7 +52,7 @@ type httpConn struct { // and some methods don't work. The panic() stubs here exist to ensure // this special treatment is correct. -func (hc *httpConn) writeJSON(context.Context, interface{}) error { +func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error { panic("writeJSON called on httpConn") } @@ -108,6 +109,11 @@ var DefaultHTTPTimeouts = HTTPTimeouts{ IdleTimeout: 120 * time.Second, } +// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. +func DialHTTP(endpoint string) (*Client, error) { + return DialHTTPWithClient(endpoint, new(http.Client)) +} + // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP // using the provided HTTP Client. func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { @@ -117,24 +123,35 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { return nil, err } - initctx := context.Background() - headers := make(http.Header, 2) + var cfg clientConfig + cfg.httpClient = client + fn := newClientTransportHTTP(endpoint, &cfg) + return newClient(context.Background(), &cfg, fn) +} + +func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { + headers := make(http.Header, 2+len(cfg.httpHeaders)) headers.Set("accept", contentType) headers.Set("content-type", contentType) - return newClient(initctx, func(context.Context) (ServerCodec, error) { - hc := &httpConn{ - client: client, - headers: headers, - url: endpoint, - closeCh: make(chan interface{}), - } - return hc, nil - }) -} + for key, values := range cfg.httpHeaders { + headers[key] = values + } -// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. -func DialHTTP(endpoint string) (*Client, error) { - return DialHTTPWithClient(endpoint, new(http.Client)) + client := cfg.httpClient + if client == nil { + client = new(http.Client) + } + + hc := &httpConn{ + client: client, + headers: headers, + url: endpoint, + closeCh: make(chan interface{}), + } + + return func(ctx context.Context) (ServerCodec, error) { + return hc, nil + } } func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { @@ -145,11 +162,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e } defer respBody.Close() - var respmsg jsonrpcMessage - if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil { + var resp jsonrpcMessage + batch := [1]*jsonrpcMessage{&resp} + if err := json.NewDecoder(respBody).Decode(&resp); err != nil { return err } - op.resp <- &respmsg + op.resp <- batch[:] return nil } @@ -160,13 +178,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr return err } defer respBody.Close() - var respmsgs []jsonrpcMessage + + var respmsgs []*jsonrpcMessage if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { return err } - for i := 0; i < len(respmsgs); i++ { - op.resp <- &respmsgs[i] - } + op.resp <- respmsgs return nil } @@ -288,3 +305,35 @@ func validateRequest(r *http.Request) (int, error) { err := fmt.Errorf("invalid content type, only %s is supported", contentType) return http.StatusUnsupportedMediaType, err } + +// ContextRequestTimeout returns the request timeout derived from the given context. +func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) { + timeout := time.Duration(math.MaxInt64) + hasTimeout := false + setTimeout := func(d time.Duration) { + if d < timeout { + timeout = d + hasTimeout = true + } + } + + if deadline, ok := ctx.Deadline(); ok { + setTimeout(time.Until(deadline)) + } + + // If the context is an HTTP request context, use the server's WriteTimeout. + httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server) + if ok && httpSrv.WriteTimeout > 0 { + wt := httpSrv.WriteTimeout + // When a write timeout is configured, we need to send the response message before + // the HTTP server cuts connection. So our internal timeout must be earlier than + // the server's true timeout. + // + // Note: Timeouts are sanitized to be a minimum of 1 second. + // Also see issue: https://github.com/golang/go/issues/47229 + wt -= 100 * time.Millisecond + setTimeout(wt) + } + + return timeout, hasTimeout +} diff --git a/rpc/inproc.go b/rpc/inproc.go index fbe9a40ceca9f..306974e04b81f 100644 --- a/rpc/inproc.go +++ b/rpc/inproc.go @@ -24,7 +24,8 @@ import ( // DialInProc attaches an in-process connection to the given RPC server. func DialInProc(handler *Server) *Client { initctx := context.Background() - c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) { p1, p2 := net.Pipe() go handler.ServeCodec(NewCodec(p1), 0) return NewCodec(p2), nil diff --git a/rpc/ipc.go b/rpc/ipc.go index 07a211c6277c4..a08245b270891 100644 --- a/rpc/ipc.go +++ b/rpc/ipc.go @@ -46,11 +46,16 @@ func (s *Server) ServeListener(l net.Listener) error { // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialIPC(ctx context.Context, endpoint string) (*Client, error) { - return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + return newClient(ctx, cfg, newClientTransportIPC(endpoint)) +} + +func newClientTransportIPC(endpoint string) reconnectFunc { + return func(ctx context.Context) (ServerCodec, error) { conn, err := newIPCConnection(ctx, endpoint) if err != nil { return nil, err } return NewCodec(conn), err - }) + } } diff --git a/rpc/json.go b/rpc/json.go index 6024f1e7dc9bf..54c618ebd68eb 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -165,18 +165,22 @@ type ConnRemoteAddr interface { // support for parsing arguments and serializing (result) objects. type jsonCodec struct { remote string - closer sync.Once // close closed channel once - closeCh chan interface{} // closed on Close - decode func(v interface{}) error // decoder to allow multiple transports - encMu sync.Mutex // guards the encoder - encode func(v interface{}) error // encoder to allow multiple transports + closer sync.Once // close closed channel once + closeCh chan interface{} // closed on Close + decode decodeFunc // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode encodeFunc // encoder to allow multiple transports conn deadlineCloser } +type encodeFunc = func(v interface{}, isErrorResponse bool) error + +type decodeFunc = func(v interface{}) error + // NewFuncCodec creates a codec which uses the given functions to read and write. If conn // implements ConnRemoteAddr, log messages will use it to include the remote address of // the connection. -func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { +func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec { codec := &jsonCodec{ closeCh: make(chan interface{}), encode: encode, @@ -195,7 +199,10 @@ func NewCodec(conn Conn) ServerCodec { enc := json.NewEncoder(conn) dec := json.NewDecoder(conn) dec.UseNumber() - return NewFuncCodec(conn, enc.Encode, dec.Decode) + encode := func(v interface{}, isErrorResponse bool) error { + return enc.Encode(v) + } + return NewFuncCodec(conn, encode, dec.Decode) } func (c *jsonCodec) peerInfo() PeerInfo { @@ -225,7 +232,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err return messages, batch, nil } -func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { +func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error { c.encMu.Lock() defer c.encMu.Unlock() @@ -234,7 +241,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { deadline = time.Now().Add(defaultWriteTimeout) } c.conn.SetWriteDeadline(deadline) - return c.encode(v) + return c.encode(v, isErrorResponse) } func (c *jsonCodec) close() { diff --git a/rpc/server.go b/rpc/server.go index babc5688e2648..745728eb5d1f5 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -43,10 +43,12 @@ const ( // Server is an RPC server. type Server struct { - services serviceRegistry - idgen func() ID - run int32 - codecs mapset.Set + services serviceRegistry + idgen func() ID + run int32 + codecs mapset.Set + batchItemLimit int + batchResponseLimit int } // NewServer creates a new server instance with no registered handlers. @@ -59,6 +61,17 @@ func NewServer() *Server { return server } +// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit' +// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of +// response bytes across all requests in a batch. +// +// This method should be called before processing any requests via ServeCodec, ServeHTTP, +// ServeListener etc. +func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) { + s.batchItemLimit = itemLimit + s.batchResponseLimit = maxResponseSize +} + // RegisterName creates a service for the given receiver type under the given name. When no // methods on the given receiver match the criteria to be either a RPC method or a // subscription an error is returned. Otherwise a new service is created and added to the @@ -84,7 +97,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { s.codecs.Add(codec) defer s.codecs.Remove(codec) - c := initClient(codec, s.idgen, &s.services) + cfg := &clientConfig{ + idgen: s.idgen, + batchItemLimit: s.batchItemLimit, + batchResponseLimit: s.batchResponseLimit, + } + c := initClient(codec, &s.services, cfg) <-codec.closed() c.Close() } @@ -98,14 +116,14 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { return } - h := newHandler(ctx, codec, s.idgen, &s.services) + h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit) h.allowSubscribe = false defer h.close(io.EOF, nil) reqs, batch, err := codec.readBatch() if err != nil { if err != io.EOF { - codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"})) + codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"}), true) } return } diff --git a/rpc/server_test.go b/rpc/server_test.go index d09d31634beeb..bad224d66283b 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -70,6 +70,7 @@ func TestServer(t *testing.T) { func runTestScript(t *testing.T, file string) { server := newTestServer() + server.SetBatchLimits(4, 100000) content, err := os.ReadFile(file) if err != nil { t.Fatal(err) @@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) { } } } + +func TestServerBatchResponseSizeLimit(t *testing.T) { + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(100, 60) + var ( + batch []BatchElem + client = DialInProc(server) + ) + for i := 0; i < 5; i++ { + batch = append(batch, BatchElem{ + Method: "test_echo", + Args: []any{"x", 1}, + Result: new(echoResult), + }) + } + if err := client.BatchCall(batch); err != nil { + t.Fatal("error sending batch:", err) + } + for i := range batch { + // We expect the first two queries to be ok, but after that the size limit takes effect. + if i < 2 { + if batch[i].Error != nil { + t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error) + } + continue + } + // After two, we expect an error. + re, ok := batch[i].Error.(Error) + if !ok { + t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error) + } + wantedCode := errcodeResponseTooLarge + if re.ErrorCode() != wantedCode { + t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode) + } + } +} diff --git a/rpc/stdio.go b/rpc/stdio.go index be2bab1c98bd1..084e5f0700ced 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -32,12 +32,17 @@ func DialStdIO(ctx context.Context) (*Client, error) { // DialIO creates a client which uses the given IO channels func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { - return newClient(ctx, func(_ context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + return newClient(ctx, cfg, newClientTransportIO(in, out)) +} + +func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc { + return func(context.Context) (ServerCodec, error) { return NewCodec(stdioConn{ in: in, out: out, }), nil - }) + } } type stdioConn struct { diff --git a/rpc/subscription.go b/rpc/subscription.go index d7ba784fc532d..569546d67d8e8 100644 --- a/rpc/subscription.go +++ b/rpc/subscription.go @@ -179,7 +179,7 @@ func (n *Notifier) send(sub *Subscription, data json.RawMessage) error { Version: vsn, Method: n.namespace + notificationMethodSuffix, Params: params, - }) + }, false) } // A Subscription is created by a notifier and tied to that notifier. The client can use diff --git a/rpc/testdata/invalid-batch-toolarge.js b/rpc/testdata/invalid-batch-toolarge.js new file mode 100644 index 0000000000000..218fea58aaac2 --- /dev/null +++ b/rpc/testdata/invalid-batch-toolarge.js @@ -0,0 +1,13 @@ +// This file checks the behavior of the batch item limit code. +// In tests, the batch item limit is set to 4. So to trigger the error, +// all batches in this file have 5 elements. + +// For batches that do not contain any calls, a response message with "id" == null +// is returned. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}] + +// For batches with at least one call, the call's "id" is used. +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}] diff --git a/rpc/types.go b/rpc/types.go index f4d05be48cd47..76c0862caa18a 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -51,7 +51,7 @@ type ServerCodec interface { // jsonWriter can write JSON messages to its underlying connection. // Implementations must be safe for concurrent use. type jsonWriter interface { - writeJSON(context.Context, interface{}) error + writeJSON(context.Context, interface{}, bool) error // Closed returns a channel which is closed when the connection is closed. closed() <-chan interface{} // RemoteAddr returns the peer address of the connection. diff --git a/rpc/websocket.go b/rpc/websocket.go index 28380d8aa4ae0..f0b8839483637 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -184,21 +184,16 @@ func parseOriginURL(origin string) (string, string, string, error) { // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint using the provided dialer. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { - endpoint, header, err := wsClientHeaders(endpoint, origin) + cfg := new(clientConfig) + cfg.wsDialer = &dialer + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) if err != nil { return nil, err } - return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { - conn, resp, err := dialer.DialContext(ctx, endpoint, header) - if err != nil { - hErr := wsHandshakeError{err: err} - if resp != nil { - hErr.status = resp.Status - } - return nil, hErr - } - return newWebsocketCodec(conn, endpoint, header), nil - }) + return newClient(ctx, cfg, connect) } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -207,12 +202,54 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { - dialer := websocket.Dialer{ - ReadBufferSize: wsReadBuffer, - WriteBufferSize: wsWriteBuffer, - WriteBufferPool: wsBufferPool, + cfg := new(clientConfig) + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) + if err != nil { + return nil, err + } + return newClient(ctx, cfg, connect) +} + +func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { + dialer := cfg.wsDialer + if dialer == nil { + dialer = &websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + Proxy: http.ProxyFromEnvironment, + } + } + + dialURL, header, err := wsClientHeaders(endpoint, "") + if err != nil { + return nil, err + } + for key, values := range cfg.httpHeaders { + header[key] = values } - return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) + + connect := func(ctx context.Context) (ServerCodec, error) { + header := header.Clone() + if cfg.httpAuth != nil { + if err := cfg.httpAuth(header); err != nil { + return nil, err + } + } + conn, resp, err := dialer.DialContext(ctx, dialURL, header) + if err != nil { + hErr := wsHandshakeError{err: err} + if resp != nil { + hErr.status = resp.Status + } + return nil, hErr + } + return newWebsocketCodec(conn, dialURL, header), nil + } + return connect, nil } func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { @@ -247,8 +284,11 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve conn.SetReadDeadline(time.Time{}) return nil }) + encode := func(v interface{}, isErrorResponse bool) error { + return conn.WriteJSON(v) + } wc := &websocketCodec{ - jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), + jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), conn: conn, pingReset: make(chan struct{}, 1), info: PeerInfo{ @@ -275,8 +315,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo { return wc.info } -func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { - err := wc.jsonCodec.writeJSON(ctx, v) +func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { + err := wc.jsonCodec.writeJSON(ctx, v, isError) if err == nil { // Notify pingLoop to delay the next idle ping. select {