Skip to content

Commit

Permalink
ARROW-17966: [C++] Adjust to new format for Substrait optional arguments (apache#15)
Browse files Browse the repository at this point in the history

* ARROW-17966: Updated to latest Substrait version.  Switched from optional enum args to proper options.  Added check for minimum Substrait version

* ARROW-17966: Add version to Python Substrait examples.  Fix version handling to check major version and not just minor

* ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc

Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>

* ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc

Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>

* ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc

Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>

* ARROW-17966: Update cpp/src/arrow/engine/substrait/extension_set.cc

Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>

* ARROW-17966: Display the available choices when a user enters a valid Substrait option that Acero doesn't support

* ARROW-17966: Simplify parsing boilerplate per review comments

* ARROW-17966: Gracefully error if the user does not supply any preferences for an option

* ARROW-17966: Prefer range loops where possible

* ARROW-17966: Rebase cleanup

* ARROW-17966: Minor fix to failing unit tests: remove enum="unspecified"

* ARROW-17966: Minor lint fix

* ARROW-17966: Cmake format

Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>
  • Loading branch information
westonpace and bkietz committed Oct 20, 2022
1 parent efd5ae0 commit 73efef8
Show file tree
Hide file tree
Showing 11 changed files with 374 additions and 126 deletions.
7 changes: 7 additions & 0 deletions cpp/cmake_modules/ThirdpartyToolchain.cmake
Expand Up @@ -666,6 +666,13 @@ else()
endif()
endif()

# Remove the two set() commands below once https://github.com/substrait-io/substrait/pull/342 merges
set(ENV{ARROW_SUBSTRAIT_URL}
"https://github.com/substrait-io/substrait/archive/e59008b6b202f8af06c2266991161b1e45cb056a.tar.gz"
)
set(ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM
"f64629cb377fcc62c9d3e8fe69fa6a4cf326f34d756e03db84843c5cce8d04cd")

if(DEFINED ENV{ARROW_SUBSTRAIT_URL})
set(SUBSTRAIT_SOURCE_URL "$ENV{ARROW_SUBSTRAIT_URL}")
else()
Expand Down
37 changes: 23 additions & 14 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Expand Up @@ -52,18 +52,15 @@ Id NormalizeFunctionName(Id id) {

} // namespace

Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
SubstraitCall* call, const ExtensionSet& ext_set,
Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* call,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
if (arg.has_enum_()) {
const substrait::FunctionArgument::Enum& enum_val = arg.enum_();
switch (enum_val.enum_kind_case()) {
case substrait::FunctionArgument::Enum::EnumKindCase::kSpecified:
call->SetEnumArg(idx, enum_val.specified());
break;
case substrait::FunctionArgument::Enum::EnumKindCase::kUnspecified:
call->SetEnumArg(idx, std::nullopt);
break;
default:
return Status::Invalid("Unrecognized enum kind case: ",
enum_val.enum_kind_case());
Expand All @@ -80,15 +77,31 @@ Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
return Status::OK();
}

// Transfers the preference list of one Substrait function option onto `call`.
//
// Returns Status::Invalid when the option carries no preferences at all;
// otherwise forwards the option name and its preferences, in their listed
// order, to SubstraitCall::SetOption and returns Status::OK.
Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) {
  // Reject an option that names no choices before doing any other work.
  if (opt.preference_size() == 0) {
    return Status::Invalid("Invalid Substrait plan. The option ", opt.name(),
                           " is specified but does not list any choices");
  }

  // The views borrow from `opt`, which outlives the SetOption call below.
  const auto& choices = opt.preference();
  const std::vector<std::string_view> choice_views(choices.begin(), choices.end());
  call->SetOption(opt.name(), choice_views);
  return Status::OK();
}

Result<SubstraitCall> DecodeScalarFunction(
Id id, const substrait::Expression::ScalarFunction& scalar_fn,
const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
FromProto(scalar_fn.output_type(), ext_set, conversion_options));
SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second);
for (int i = 0; i < scalar_fn.arguments_size(); i++) {
ARROW_RETURN_NOT_OK(DecodeArg(scalar_fn.arguments(i), static_cast<uint32_t>(i), &call,
ext_set, conversion_options));
ARROW_RETURN_NOT_OK(
DecodeArg(scalar_fn.arguments(i), i, &call, ext_set, conversion_options));
}
for (const auto& opt : scalar_fn.options()) {
ARROW_RETURN_NOT_OK(DecodeOption(opt, &call));
}
return std::move(call);
}
Expand Down Expand Up @@ -926,16 +939,12 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options));
scalar_fn->set_allocated_output_type(output_type.release());

for (uint32_t i = 0; i < call.size(); i++) {
for (int i = 0; i < call.size(); i++) {
substrait::FunctionArgument* arg = scalar_fn->add_arguments();
if (call.HasEnumArg(i)) {
auto enum_val = std::make_unique<substrait::FunctionArgument::Enum>();
ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg, call.GetEnumArg(i));
if (enum_arg) {
enum_val->set_specified(std::string(*enum_arg));
} else {
enum_val->set_allocated_unspecified(new google::protobuf::Empty());
}
ARROW_ASSIGN_OR_RAISE(std::string_view enum_arg, call.GetEnumArg(i));
enum_val->set_specified(std::string(enum_arg));
arg->set_allocated_enum_(enum_val.release());
} else if (call.HasValueArg(i)) {
ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i));
Expand Down
145 changes: 102 additions & 43 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Expand Up @@ -25,6 +25,7 @@
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/util/hash_util.h"
#include "arrow/util/hashing.h"
#include "arrow/util/string.h"

namespace arrow {
namespace engine {
Expand Down Expand Up @@ -121,7 +122,7 @@ class IdStorageImpl : public IdStorage {

std::unique_ptr<IdStorage> IdStorage::Make() { return std::make_unique<IdStorageImpl>(); }

Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index) const {
Result<std::string_view> SubstraitCall::GetEnumArg(int index) const {
if (index >= size_) {
return Status::Invalid("Expected Substrait call to have an enum argument at index ",
index, " but it did not have enough arguments");
Expand All @@ -134,16 +135,16 @@ Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index
return enum_arg_it->second;
}

bool SubstraitCall::HasEnumArg(uint32_t index) const {
bool SubstraitCall::HasEnumArg(int index) const {
return enum_args_.find(index) != enum_args_.end();
}

void SubstraitCall::SetEnumArg(uint32_t index, std::optional<std::string> enum_arg) {
void SubstraitCall::SetEnumArg(int index, std::string enum_arg) {
size_ = std::max(size_, index + 1);
enum_args_[index] = std::move(enum_arg);
}

Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
Result<compute::Expression> SubstraitCall::GetValueArg(int index) const {
if (index >= size_) {
return Status::Invalid("Expected Substrait call to have a value argument at index ",
index, " but it did not have enough arguments");
Expand All @@ -156,15 +157,32 @@ Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
return value_arg_it->second;
}

bool SubstraitCall::HasValueArg(uint32_t index) const {
bool SubstraitCall::HasValueArg(int index) const {
return value_args_.find(index) != value_args_.end();
}

void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) {
void SubstraitCall::SetValueArg(int index, compute::Expression value_arg) {
size_ = std::max(size_, index + 1);
value_args_[index] = std::move(value_arg);
}

// Looks up the preferences stored under `option_name`.
//
// Returns std::nullopt when options_ has no entry for that name; otherwise a
// non-owning pointer to the stored preference vector inside options_.
std::optional<std::vector<std::string> const*> SubstraitCall::GetOption(
    std::string_view option_name) const {
  const auto found = options_.find(std::string(option_name));
  if (found != options_.end()) {
    return &found->second;
  }
  return std::nullopt;
}

// Records `option_preferences` under `option_name`, copying each view into an
// owned std::string.
//
// The entry is created if it does not exist yet. If it already exists, the new
// preferences are appended after the ones stored previously rather than
// replacing them.
void SubstraitCall::SetOption(std::string_view option_name,
                              const std::vector<std::string_view>& option_preferences) {
  std::vector<std::string>& stored = options_[std::string(option_name)];
  stored.insert(stored.end(), option_preferences.begin(), option_preferences.end());
}

// A builder used when creating a Substrait plan from an Arrow execution plan. In
// that situation we do not have a set of anchor values already defined so we keep
// a map of what Ids we have seen.
Expand Down Expand Up @@ -645,50 +663,91 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
};

template <typename Enum>
using EnumParser = std::function<Result<Enum>(std::optional<std::string_view>)>;

template <typename Enum>
EnumParser<Enum> GetEnumParser(const std::vector<std::string>& options) {
std::unordered_map<std::string, Enum> parse_map;
for (std::size_t i = 0; i < options.size(); i++) {
parse_map[options[i]] = static_cast<Enum>(i + 1);
class EnumParser {
public:
explicit EnumParser(const std::vector<std::string>& options) {
for (std::size_t i = 0; i < options.size(); i++) {
parse_map_[options[i]] = static_cast<Enum>(i + 1);
reverse_map_[static_cast<Enum>(i + 1)] = options[i];
}
}
return [parse_map](std::optional<std::string_view> enum_val) -> Result<Enum> {
if (!enum_val) {
// Assumes 0 is always kUnspecified in Enum
return static_cast<Enum>(0);

Result<Enum> Parse(std::string_view enum_val) const {
auto it = parse_map_.find(std::string(enum_val));
if (it == parse_map_.end()) {
return Status::NotImplemented("The value ", enum_val,
" is not an expected enum value");
}
auto maybe_parsed = parse_map.find(std::string(*enum_val));
if (maybe_parsed == parse_map.end()) {
return Status::Invalid("The value ", *enum_val, " is not an expected enum value");
return it->second;
}

std::string ImplementedOptionsAsString(
const std::vector<Enum>& implemented_opts) const {
std::vector<std::string_view> opt_strs;
for (const Enum& implemented_opt : implemented_opts) {
auto it = reverse_map_.find(implemented_opt);
if (it == reverse_map_.end()) {
opt_strs.emplace_back("Unknown");
} else {
opt_strs.emplace_back(it->second);
}
}
return maybe_parsed->second;
};
}
return arrow::internal::JoinStrings(opt_strs, ", ");
}

private:
std::unordered_map<std::string, Enum> parse_map_;
std::unordered_map<Enum, std::string> reverse_map_;
};

enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond };
static std::vector<std::string> kTemporalComponentOptions = {"YEAR", "MONTH", "DAY",
"SECOND"};
static EnumParser<TemporalComponent> kTemporalComponentParser =
GetEnumParser<TemporalComponent>(kTemporalComponentOptions);
static EnumParser<TemporalComponent> kTemporalComponentParser(kTemporalComponentOptions);

enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError };
static std::vector<std::string> kOverflowOptions = {"SILENT", "SATURATE", "ERROR"};
static EnumParser<OverflowBehavior> kOverflowParser =
GetEnumParser<OverflowBehavior>(kOverflowOptions);
static EnumParser<OverflowBehavior> kOverflowParser(kOverflowOptions);

template <typename Enum>
Result<Enum> ParseEnumArg(const SubstraitCall& call, uint32_t arg_index,
Result<Enum> ParseOptionOrElse(const SubstraitCall& call, std::string_view option_name,
const EnumParser<Enum>& parser,
const std::vector<Enum>& implemented_options,
Enum fallback) {
std::optional<std::vector<std::string> const*> enum_arg = call.GetOption(option_name);
if (!enum_arg.has_value()) {
return fallback;
}
std::vector<std::string> const* prefs = *enum_arg;
for (const std::string& pref : *prefs) {
ARROW_ASSIGN_OR_RAISE(Enum parsed, parser.Parse(pref));
for (Enum implemented_opt : implemented_options) {
if (implemented_opt == parsed) {
return parsed;
}
}
}

// Prepare error message
return Status::NotImplemented(
"During a call to a function with id ", call.id().uri, "#", call.id().name,
" the plan requested the option ", option_name, " to be one of [",
arrow::internal::JoinStrings(*prefs, ", "),
"] but the only supported options are [",
parser.ImplementedOptionsAsString(implemented_options), "]");
}

template <typename Enum>
Result<Enum> ParseEnumArg(const SubstraitCall& call, int arg_index,
const EnumParser<Enum>& parser) {
ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg,
call.GetEnumArg(arg_index));
return parser(enum_arg);
ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(arg_index));
return parser.Parse(enum_val);
}

Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
int start_index) {
std::vector<compute::Expression> expressions;
for (uint32_t index = start_index; index < call.size(); index++) {
for (int index = start_index; index < call.size(); index++) {
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index));
expressions.push_back(arg);
}
Expand All @@ -698,13 +757,13 @@ Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic(
const std::string& function_name) {
return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior,
ParseEnumArg(call, 0, kOverflowParser));
ARROW_ASSIGN_OR_RAISE(
OverflowBehavior overflow_behavior,
ParseOptionOrElse(call, "overflow", kOverflowParser,
{OverflowBehavior::kSilent, OverflowBehavior::kError},
OverflowBehavior::kSilent));
ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
GetValueArgs(call, 1));
if (overflow_behavior == OverflowBehavior::kUnspecified) {
overflow_behavior = OverflowBehavior::kSilent;
}
GetValueArgs(call, 0));
if (overflow_behavior == OverflowBehavior::kSilent) {
return arrow::compute::call(function_name, std::move(value_args));
} else if (overflow_behavior == OverflowBehavior::kError) {
Expand All @@ -727,12 +786,12 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
/*nullable=*/true);
if (kChecked) {
substrait_call.SetEnumArg(0, "ERROR");
substrait_call.SetOption("overflow", {"ERROR"});
} else {
substrait_call.SetEnumArg(0, "SILENT");
substrait_call.SetOption("overflow", {"SILENT"});
}
for (std::size_t i = 0; i < call.arguments.size(); i++) {
substrait_call.SetValueArg(static_cast<uint32_t>(i + 1), call.arguments[i]);
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
}
return std::move(substrait_call);
};
Expand All @@ -746,14 +805,14 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
/*nullable=*/true);
for (std::size_t i = 0; i < call.arguments.size(); i++) {
substrait_call.SetValueArg(static_cast<uint32_t>(i), call.arguments[i]);
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
}
return std::move(substrait_call);
};
}

ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping(
const std::string& function_name, uint32_t max_args) {
const std::string& function_name, int max_args) {
return [function_name,
max_args](const SubstraitCall& call) -> Result<compute::Expression> {
if (call.size() > max_args) {
Expand Down
25 changes: 15 additions & 10 deletions cpp/src/arrow/engine/substrait/extension_set.h
Expand Up @@ -119,13 +119,17 @@ class SubstraitCall {
bool output_nullable() const { return output_nullable_; }
bool is_hash() const { return is_hash_; }

bool HasEnumArg(uint32_t index) const;
Result<std::optional<std::string_view>> GetEnumArg(uint32_t index) const;
void SetEnumArg(uint32_t index, std::optional<std::string> enum_arg);
Result<compute::Expression> GetValueArg(uint32_t index) const;
bool HasValueArg(uint32_t index) const;
void SetValueArg(uint32_t index, compute::Expression value_arg);
uint32_t size() const { return size_; }
bool HasEnumArg(int index) const;
Result<std::string_view> GetEnumArg(int index) const;
void SetEnumArg(int index, std::string enum_arg);
Result<compute::Expression> GetValueArg(int index) const;
bool HasValueArg(int index) const;
void SetValueArg(int index, compute::Expression value_arg);
std::optional<std::vector<std::string> const*> GetOption(
std::string_view option_name) const;
void SetOption(std::string_view option_name,
const std::vector<std::string_view>& option_preferences);
int size() const { return size_; }

private:
Id id_;
Expand All @@ -134,9 +138,10 @@ class SubstraitCall {
// Only needed when converting from Substrait -> Arrow aggregates. The
// Arrow function name depends on whether or not there are any groups
bool is_hash_;
std::unordered_map<uint32_t, std::optional<std::string>> enum_args_;
std::unordered_map<uint32_t, compute::Expression> value_args_;
uint32_t size_ = 0;
std::unordered_map<int, std::string> enum_args_;
std::unordered_map<int, compute::Expression> value_args_;
std::unordered_map<std::string, std::vector<std::string>> options_;
int size_ = 0;
};

/// Substrait identifies functions and custom data types using a (uri, name) pair.
Expand Down

0 comments on commit 73efef8

Please sign in to comment.