Skip to content

Commit

Permalink
Fix handling of cached values in Rack::Request. (rack#2054)
Browse files Browse the repository at this point in the history
* Per-class cache keys for cached query/body parameters.

* Use the query parser class as the default cache key.
  • Loading branch information
ioquatix committed Mar 15, 2023
1 parent 6c6b07b commit 9d7aa4f
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 49 deletions.
154 changes: 110 additions & 44 deletions lib/rack/request.rb
Original file line number Diff line number Diff line change
Expand Up @@ -480,25 +480,114 @@ def parseable_data?
PARSEABLE_DATA_MEDIA_TYPES.include?(media_type)
end

# Returns the data received in the query string.
def GET
if get_header(RACK_REQUEST_QUERY_STRING) == query_string
if query_hash = get_header(RACK_REQUEST_QUERY_HASH)
return query_hash
# Given a current input value, and a validity key, check if the cache
# is valid, and if so, return the cached value. If not, yield the
# current value to the block, and set the cache to the result.
#
# This method does not use cache_key, so it is shared between all
# instance of Rack::Request and it's sub-classes.
private def cache_for(key, validity_key, current_value)
# Get the current value of the validity key and compare it with the input value:
if get_header(validity_key).equal?(current_value)
# If the values are the same, then the cache is valid, so return the cached value.
if has_header?(key)
value = get_header(key)
# If the cached value is an exception, then re-raise it.
if value.is_a?(Exception)
raise value.class, value.message, cause: value.cause
else
# Otherwise, return the cached value.
return value
end
end
end

set_header(RACK_REQUEST_QUERY_HASH, expand_params(query_param_list))
# If the cache is not valid, then yield the current value to the block:
value = yield(current_value)

# Set the validity key to the current value so that we can detect changes:
set_header(validity_key, current_value)

# Set the cache to the result of the block, and return the result:
set_header(key, value)
rescue => error
# If an exception is raised, then set the cache to the exception, and re-raise it:
set_header(validity_key, current_value)
set_header(key, error)
raise
end

# This cache key is used by cached values generated by class_cache_for,
# specfically GET and POST. This is to ensure that the cache is not
# shared between instances of different classes which have different
# behaviour. This includes sub-classes that override query_parser or
# expand_params.
def cache_key
query_parser.class
end

# Given a current input value, and a validity key, check if the cache
# is valid, and if so, return the cached value. If not, yield the
# current value to the block, and set the cache to the result.
#
# This method uses cache_key to ensure that the cache is not shared
# between instances of different classes which have different
# behaviour of the cached operations.
private def class_cache_for(key, validity_key, current_value)
# The cache is organised in the env as:
# env[key][cache_key] = value
# and is valid as long as env[validity_key].equal?(current_value)

cache_key = self.cache_key

# Get the current value of the validity key and compare it with the input value:
if get_header(validity_key).equal?(current_value)
# Lookup the cache for the current cache key:
if cache = get_header(key)
if cache.key?(cache_key)
# If the cache is valid, then return the cached value.
value = cache[cache_key]
if value.is_a?(Exception)
# If the cached value is an exception, then re-raise it.
raise value.class, value.message, cause: value.cause
else
# Otherwise, return the cached value.
return value
end
end
end
end

# If the cache was not defined for this cache key, then create a new cache:
unless cache
set_header(key, cache = {})
end

begin
# Yield the current value to the block to generate an updated value:
value = yield(current_value)

# Only set this after generating the value, so that if an error or other cache depending on the same key, it will be invalidated correctly:
set_header(validity_key, current_value)
return cache[cache_key] = value
rescue => error
set_header(validity_key, current_value)
cache[cache_key] = error
raise
end
end

# Returns the data received in the query string.
def GET
class_cache_for(RACK_REQUEST_QUERY_HASH, RACK_REQUEST_QUERY_STRING, query_string) do
expand_params(query_param_list)
end
end

def query_param_list
if get_header(RACK_REQUEST_QUERY_STRING) == query_string
get_header(RACK_REQUEST_QUERY_PAIRS)
else
query_pairs = split_query(query_string, '&')
set_header RACK_REQUEST_QUERY_STRING, query_string
set_header RACK_REQUEST_QUERY_HASH, nil
set_header(RACK_REQUEST_QUERY_PAIRS, query_pairs)
cache_for(RACK_REQUEST_QUERY_PAIRS, RACK_REQUEST_QUERY_STRING, query_string) do
set_header(RACK_REQUEST_QUERY_HASH, nil)
split_query(query_string, '&')
end
end

Expand All @@ -507,33 +596,13 @@ def query_param_list
# This method support both application/x-www-form-urlencoded and
# multipart/form-data.
def POST
if get_header(RACK_REQUEST_FORM_INPUT).equal?(get_header(RACK_INPUT))
if form_hash = get_header(RACK_REQUEST_FORM_HASH)
return form_hash
end
class_cache_for(RACK_REQUEST_FORM_HASH, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do
expand_params(body_param_list)
end

set_header(RACK_REQUEST_FORM_HASH, expand_params(body_param_list))
end

def body_param_list
if error = get_header(RACK_REQUEST_FORM_ERROR)
raise error.class, error.message, cause: error.cause
end

begin
rack_input = get_header(RACK_INPUT)

form_pairs = nil

# If the form data has already been memoized from the same
# input:
if get_header(RACK_REQUEST_FORM_INPUT).equal?(rack_input)
if form_pairs = get_header(RACK_REQUEST_FORM_PAIRS)
return form_pairs
end
end

cache_for(RACK_REQUEST_FORM_PAIRS, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do |rack_input|
if rack_input.nil?
form_pairs = []
elsif form_data? || parseable_data?
Expand All @@ -544,19 +613,16 @@ def body_param_list
# form_vars.sub!(/\0\z/, '') # performance replacement:
form_vars.slice!(-1) if form_vars.end_with?("\0")

set_header RACK_REQUEST_FORM_VARS, form_vars
# Removing this line breaks Rail test "test_filters_rack_request_form_vars"!
set_header(RACK_REQUEST_FORM_VARS, form_vars)

form_pairs = split_query(form_vars, '&')
end
else
form_pairs = []
end

set_header RACK_REQUEST_FORM_INPUT, rack_input
set_header RACK_REQUEST_FORM_HASH, nil
set_header(RACK_REQUEST_FORM_PAIRS, form_pairs)
rescue => error
set_header(RACK_REQUEST_FORM_ERROR, error)
raise

form_pairs
end
end

Expand Down
66 changes: 61 additions & 5 deletions test/spec_request.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1554,12 +1554,19 @@ def initialize(*)
rack_input.write(input)
rack_input.rewind

req = make_request Rack::MockRequest.env_for("/",
"rack.request.form_hash" => { 'foo' => 'bar' },
"rack.request.form_input" => rack_input,
:input => rack_input)
form_hash_cache = {}

req = make_request Rack::MockRequest.env_for(
"/",
"rack.request.form_hash" => form_hash_cache,
"rack.request.form_input" => rack_input,
:input => rack_input
)

req.POST.must_equal req.env['rack.request.form_hash']
form_hash = {'foo' => 'bar'}.freeze
form_hash_cache[req.cache_key] = form_hash

req.POST.must_equal form_hash
end

it "conform to the Rack spec" do
Expand Down Expand Up @@ -1957,4 +1964,53 @@ def make_request(env)
DelegateRequest.new super(env)
end
end

class UpperRequest < Rack::Request
def expand_params(parameters)
parameters.map do |(key, value)|
[key.upcase, value]
end.to_h
end

# If this is not specified, the behaviour becomes order dependent.
def cache_key
:my_request
end
end

it "correctly expands parameters" do
env = {"QUERY_STRING" => "foo=bar"}

request = Rack::Request.new(env)
request.query_param_list.must_equal [["foo", "bar"]]
request.GET.must_equal "foo" => "bar"

upper_request = UpperRequest.new(env)
upper_request.query_param_list.must_equal [["foo", "bar"]]
upper_request.GET.must_equal "FOO" => "bar"

env['QUERY_STRING'] = "foo=bar&bar=baz"

request.GET.must_equal "foo" => "bar", "bar" => "baz"
upper_request.GET.must_equal "FOO" => "bar", "BAR" => "baz"
end

class BrokenRequest < Rack::Request
def expand_params(parameters)
raise "boom"
end
end

it "raises an error if expand_params raises an error" do
env = {"QUERY_STRING" => "foo=bar"}

request = Rack::Request.new(env)
request.GET.must_equal "foo" => "bar"

broken_request = BrokenRequest.new(env)
lambda { broken_request.GET }.must_raise RuntimeError

# Subsequnt calls also raise an error:
lambda { broken_request.GET }.must_raise RuntimeError
end
end

0 comments on commit 9d7aa4f

Please sign in to comment.