Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of cached values in Rack::Request. #2054

Merged
merged 4 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
154 changes: 110 additions & 44 deletions lib/rack/request.rb
Original file line number Diff line number Diff line change
Expand Up @@ -480,25 +480,114 @@ def parseable_data?
PARSEABLE_DATA_MEDIA_TYPES.include?(media_type)
end

# Returns the data received in the query string.
def GET
if get_header(RACK_REQUEST_QUERY_STRING) == query_string
if query_hash = get_header(RACK_REQUEST_QUERY_HASH)
return query_hash
# Given a current input value, and a validity key, check if the cache
# is valid, and if so, return the cached value. If not, yield the
# current value to the block, and set the cache to the result.
#
# This method does not use cache_key, so it is shared between all
# instance of Rack::Request and it's sub-classes.
private def cache_for(key, validity_key, current_value)
# Get the current value of the validity key and compare it with the input value:
if get_header(validity_key).equal?(current_value)
# If the values are the same, then the cache is valid, so return the cached value.
if has_header?(key)
value = get_header(key)
# If the cached value is an exception, then re-raise it.
if value.is_a?(Exception)
raise value.class, value.message, cause: value.cause
else
# Otherwise, return the cached value.
return value
end
end
end

set_header(RACK_REQUEST_QUERY_HASH, expand_params(query_param_list))
# If the cache is not valid, then yield the current value to the block:
value = yield(current_value)

# Set the validity key to the current value so that we can detect changes:
set_header(validity_key, current_value)

# Set the cache to the result of the block, and return the result:
set_header(key, value)
rescue => error
# If an exception is raised, then set the cache to the exception, and re-raise it:
set_header(validity_key, current_value)
set_header(key, error)
raise
end

# This cache key is used by cached values generated by class_cache_for,
# specfically GET and POST. This is to ensure that the cache is not
# shared between instances of different classes which have different
# behaviour. This includes sub-classes that override query_parser or
# expand_params.
def cache_key
query_parser.class
end

# Given a current input value, and a validity key, check if the cache
# is valid, and if so, return the cached value. If not, yield the
# current value to the block, and set the cache to the result.
#
# This method uses cache_key to ensure that the cache is not shared
# between instances of different classes which have different
# behaviour of the cached operations.
private def class_cache_for(key, validity_key, current_value)
# The cache is organised in the env as:
# env[key][cache_key] = value
# and is valid as long as env[validity_key].equal?(current_value)

cache_key = self.cache_key

# Get the current value of the validity key and compare it with the input value:
if get_header(validity_key).equal?(current_value)
# Lookup the cache for the current cache key:
if cache = get_header(key)
if cache.key?(cache_key)
# If the cache is valid, then return the cached value.
value = cache[cache_key]
if value.is_a?(Exception)
# If the cached value is an exception, then re-raise it.
raise value.class, value.message, cause: value.cause
else
# Otherwise, return the cached value.
return value
end
end
end
end

# If the cache was not defined for this cache key, then create a new cache:
unless cache
set_header(key, cache = {})
end

begin
# Yield the current value to the block to generate an updated value:
value = yield(current_value)

# Only set this after generating the value, so that if an error or other cache depending on the same key, it will be invalidated correctly:
set_header(validity_key, current_value)
return cache[cache_key] = value
rescue => error
set_header(validity_key, current_value)
cache[cache_key] = error
ioquatix marked this conversation as resolved.
Show resolved Hide resolved
raise
end
end

# Returns the data received in the query string.
def GET
class_cache_for(RACK_REQUEST_QUERY_HASH, RACK_REQUEST_QUERY_STRING, query_string) do
expand_params(query_param_list)
end
end

def query_param_list
if get_header(RACK_REQUEST_QUERY_STRING) == query_string
get_header(RACK_REQUEST_QUERY_PAIRS)
else
query_pairs = split_query(query_string, '&')
set_header RACK_REQUEST_QUERY_STRING, query_string
set_header RACK_REQUEST_QUERY_HASH, nil
set_header(RACK_REQUEST_QUERY_PAIRS, query_pairs)
cache_for(RACK_REQUEST_QUERY_PAIRS, RACK_REQUEST_QUERY_STRING, query_string) do
set_header(RACK_REQUEST_QUERY_HASH, nil)
split_query(query_string, '&')
end
end

Expand All @@ -507,33 +596,13 @@ def query_param_list
# This method support both application/x-www-form-urlencoded and
# multipart/form-data.
def POST
if get_header(RACK_REQUEST_FORM_INPUT).equal?(get_header(RACK_INPUT))
if form_hash = get_header(RACK_REQUEST_FORM_HASH)
return form_hash
end
class_cache_for(RACK_REQUEST_FORM_HASH, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do
expand_params(body_param_list)
end

set_header(RACK_REQUEST_FORM_HASH, expand_params(body_param_list))
end

def body_param_list
if error = get_header(RACK_REQUEST_FORM_ERROR)
raise error.class, error.message, cause: error.cause
end

begin
rack_input = get_header(RACK_INPUT)

form_pairs = nil

# If the form data has already been memoized from the same
# input:
if get_header(RACK_REQUEST_FORM_INPUT).equal?(rack_input)
if form_pairs = get_header(RACK_REQUEST_FORM_PAIRS)
return form_pairs
end
end

cache_for(RACK_REQUEST_FORM_PAIRS, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do |rack_input|
if rack_input.nil?
form_pairs = []
elsif form_data? || parseable_data?
Expand All @@ -544,19 +613,16 @@ def body_param_list
# form_vars.sub!(/\0\z/, '') # performance replacement:
form_vars.slice!(-1) if form_vars.end_with?("\0")

set_header RACK_REQUEST_FORM_VARS, form_vars
# Removing this line breaks Rail test "test_filters_rack_request_form_vars"!
set_header(RACK_REQUEST_FORM_VARS, form_vars)

form_pairs = split_query(form_vars, '&')
end
else
form_pairs = []
end

set_header RACK_REQUEST_FORM_INPUT, rack_input
set_header RACK_REQUEST_FORM_HASH, nil
set_header(RACK_REQUEST_FORM_PAIRS, form_pairs)
rescue => error
set_header(RACK_REQUEST_FORM_ERROR, error)
raise

form_pairs
end
end

Expand Down
66 changes: 61 additions & 5 deletions test/spec_request.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1554,12 +1554,19 @@ def initialize(*)
rack_input.write(input)
rack_input.rewind

req = make_request Rack::MockRequest.env_for("/",
"rack.request.form_hash" => { 'foo' => 'bar' },
"rack.request.form_input" => rack_input,
:input => rack_input)
form_hash_cache = {}

req = make_request Rack::MockRequest.env_for(
"/",
"rack.request.form_hash" => form_hash_cache,
"rack.request.form_input" => rack_input,
:input => rack_input
)

req.POST.must_equal req.env['rack.request.form_hash']
form_hash = {'foo' => 'bar'}.freeze
form_hash_cache[req.cache_key] = form_hash

req.POST.must_equal form_hash
end

it "conform to the Rack spec" do
Expand Down Expand Up @@ -1957,4 +1964,53 @@ def make_request(env)
DelegateRequest.new super(env)
end
end

class UpperRequest < Rack::Request
def expand_params(parameters)
parameters.map do |(key, value)|
[key.upcase, value]
end.to_h
end

# If this is not specified, the behaviour becomes order dependent.
def cache_key
:my_request
end
end

it "correctly expands parameters" do
env = {"QUERY_STRING" => "foo=bar"}

request = Rack::Request.new(env)
request.query_param_list.must_equal [["foo", "bar"]]
request.GET.must_equal "foo" => "bar"

upper_request = UpperRequest.new(env)
upper_request.query_param_list.must_equal [["foo", "bar"]]
upper_request.GET.must_equal "FOO" => "bar"

env['QUERY_STRING'] = "foo=bar&bar=baz"

request.GET.must_equal "foo" => "bar", "bar" => "baz"
upper_request.GET.must_equal "FOO" => "bar", "BAR" => "baz"
end

class BrokenRequest < Rack::Request
def expand_params(parameters)
raise "boom"
end
end

it "raises an error if expand_params raises an error" do
env = {"QUERY_STRING" => "foo=bar"}

request = Rack::Request.new(env)
request.GET.must_equal "foo" => "bar"

broken_request = BrokenRequest.new(env)
lambda { broken_request.GET }.must_raise RuntimeError

# Subsequnt calls also raise an error:
lambda { broken_request.GET }.must_raise RuntimeError
end
end