Cleanup shape inference implementation #5596

Merged: 16 commits (Oct 5, 2023)
Changes from 1 commit
42 changes: 18 additions & 24 deletions onnx/shape_inference/implementation.cc
@@ -55,14 +55,6 @@
}
#endif
return ONNX_NAMESPACE::to_string(type.elem_type());

void UnknownOpError(const NodeProto& nodeProto) {
if (checker::check_is_experimental_op(nodeProto)) {
fail_type_inference("Experimental operator '", nodeProto.op_type(), "' no longer upported.");
}
fail_type_inference(
"Unknown operator/function: '", nodeProto.op_type(), "' in domain: '", nodeProto.domain(), "'.");
}
}

std::string GetElemTypeString(const TypeProto_SparseTensor& type) {
@@ -78,6 +70,13 @@
inline bool IsOnnxDomainOp(const NodeProto& node, const std::string& op_type) {
return (IsOnnxDomain(node.domain()) && (node.op_type() == op_type));
}

void UnknownOpError(const NodeProto& nodeProto) {
if (checker::check_is_experimental_op(nodeProto)) {
fail_type_inference("Experimental operator '", nodeProto.op_type(), "' no longer upported.");
}
fail_type_inference("Unknown operator/function: '", nodeProto.op_type(), "' in domain: '", nodeProto.domain(), "'.");
}
} // namespace

template <class T>
@@ -266,9 +265,9 @@
// for inference. For GraphProto, inferred types are stored in the GraphProto
// but FunctionProto does not have a place to store inferred types. So, we
// use a temporary vector (for the duration of inference) to store these.
class InferredTypes {

[clang-tidy, line 268] warning: class 'InferredTypes' defines a non-default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator [cppcoreguidelines-special-member-functions]
public:
InferredTypes(GraphProto* graph = nullptr) : graph_ptr(graph) {}

[cpplint, line 270] Constructors callable with one argument should be marked explicit. [runtime/explicit] [5]
[clang-tidy, line 270] warning: constructor does not initialize these fields: types, graph_ptr [cppcoreguidelines-pro-type-member-init]; suggested fix: std::vector<TypeProto*> types{}; GraphProto* graph_ptr{};
[clang-tidy, line 270] warning: single-argument constructors must be marked explicit to avoid unintentional implicit conversions [google-explicit-constructor]

TypeProto* Add(const std::string& var_name, const TypeProto& type) {
if (graph_ptr != nullptr) {
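The three lint findings above each come with a concrete remedy. A minimal sketch (with stand-in types, not the PR's actual protobuf classes) of a version that would satisfy all of them, using an explicit constructor, brace-initialized members, and explicitly deleted copy/move operations:

```cpp
#include <vector>

struct TypeProto;   // stand-in for the ONNX protobuf type
struct GraphProto;  // stand-in for the ONNX protobuf type

class InferredTypesSketch {
 public:
  // google-explicit-constructor / runtime/explicit: single-argument
  // constructors are marked explicit.
  explicit InferredTypesSketch(GraphProto* graph = nullptr) : graph_ptr(graph) {}

  // cppcoreguidelines-special-member-functions: a class with a user-provided
  // destructor also declares (here: deletes) its copy/move operations.
  InferredTypesSketch(const InferredTypesSketch&) = delete;
  InferredTypesSketch& operator=(const InferredTypesSketch&) = delete;
  InferredTypesSketch(InferredTypesSketch&&) = delete;
  InferredTypesSketch& operator=(InferredTypesSketch&&) = delete;
  ~InferredTypesSketch() {}  // the real destructor releases the temporaries

 private:
  // cppcoreguidelines-pro-type-member-init: brace-initialize all fields.
  std::vector<TypeProto*> types{};
  GraphProto* graph_ptr{};
};
```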
@@ -298,7 +297,7 @@
void BindValuesOnCall(
const DataValueMap& caller_map,
const NodeProto& caller,
DataValueMap& callee_map,

[cpplint, line 300] Is this a non-const reference? If so, make const or use a pointer: DataValueMap& callee_map [runtime/references] [2]
const FunctionProto& callee) {
auto num_inputs = (std::min)(caller.input_size(), callee.input_size());
for (int i = 0; i < num_inputs; ++i) {
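cpplint's runtime/references check, flagged here and on several signatures below, wants mutable out-parameters passed as pointers so that mutation is visible at the call site. A minimal sketch of the two styles, using a simplified stand-in for DataValueMap:

```cpp
#include <string>
#include <unordered_map>

// Simplified stand-in; the real DataValueMap holds shape/value protos.
using DataValueMap = std::unordered_map<std::string, int>;

// Style flagged by cpplint: mutation of callee_map is invisible at the call site.
void BindByRef(const DataValueMap& caller_map, DataValueMap& callee_map) {
  for (const auto& kv : caller_map) callee_map[kv.first] = kv.second;
}

// Style cpplint prefers: callers write BindByPtr(caller_map, &callee_map).
void BindByPtr(const DataValueMap& caller_map, DataValueMap* callee_map) {
  for (const auto& kv : caller_map) (*callee_map)[kv.first] = kv.second;
}
```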
@@ -317,7 +316,7 @@
void BindValuesOnReturn(
const DataValueMap& callee_map,
const FunctionProto& callee,
DataValueMap& caller_map,

[cpplint, line 319] Is this a non-const reference? If so, make const or use a pointer: DataValueMap& caller_map [runtime/references] [2]
const NodeProto& caller) {
auto num_outputs = (std::min)(caller.output_size(), callee.output_size());
for (int i = 0; i < num_outputs; ++i) {
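BindValuesOnCall and BindValuesOnReturn are mirror images: on call, known values of actual arguments are copied under the callee's formal parameter names; on return, callee output values are copied back under the caller's output names, and both clamp to the shorter of the two name lists. A reduced sketch of the shared operation, with plain strings standing in for the proto fields:

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

using DataValueMap = std::unordered_map<std::string, int>;  // simplified stand-in

// Copies known values across the call boundary: from_names[i] in the source
// map is rebound to to_names[i] in the destination map.
void BindNames(
    const DataValueMap& from_map,
    const std::vector<std::string>& from_names,  // e.g. caller.input(i) on call
    DataValueMap* to_map,
    const std::vector<std::string>& to_names) {  // e.g. callee.input(i) on call
  const std::size_t n = std::min(from_names.size(), to_names.size());
  for (std::size_t i = 0; i < n; ++i) {
    auto it = from_map.find(from_names[i]);
    if (it != from_map.end()) {
      (*to_map)[to_names[i]] = it->second;  // value travels, name is rebound
    }
  }
}
```

BindValuesOnReturn is the same operation run in the opposite direction, with output names in place of input names.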
@@ -360,7 +359,7 @@
}
}

void UpdateType(ValueInfoProto& valueInfo) {

[cpplint, line 362] Is this a non-const reference? If so, make const or use a pointer: ValueInfoProto& valueInfo [runtime/references] [2]
if (valueInfo.has_type()) {
value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
} else {
@@ -421,7 +420,7 @@
}
}

void ProcessCall(const NodeProto& caller, const FunctionProto& callee, InferenceContext& ctx) {

[cpplint, line 423] Is this a non-const reference? If so, make const or use a pointer: InferenceContext& ctx [runtime/references] [2]
DataValueMap callee_value_map;
if (generated_shape_data_by_name != nullptr) {
BindValuesOnCall(*generated_shape_data_by_name, caller, callee_value_map, callee);
@@ -433,7 +432,7 @@
}
}

void Process(NodeProto& n) {

[cpplint, line 435] Is this a non-const reference? If so, make const or use a pointer: NodeProto& n [runtime/references] [2]
// Resolve domain for node
auto dit = opset_imports.find(n.domain());
if (dit == opset_imports.end()) {
@@ -470,9 +469,13 @@
ProcessCall(n, *(schema->GetFunction()), ctx);
} else {
// Continue with inference for remaining nodes
// TODO: fix this

[cpplint, line 472] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
return;
}
// check type-constraints specified via type variables
if (options.check_type) {
schema->CheckInputOutputType(ctx);
}
} else if (model_local_functions_map.size() > 0) {
auto iter = model_local_functions_map.find(GetModelLocalFunctionsMapIdentifier(n.domain(), n.op_type()));
if (iter != model_local_functions_map.end()) {
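Between this hunk and the next, the node-resolution order becomes: a registered OpSchema (running its inference function, or ProcessCall when the schema is itself a function op, plus the optional CheckInputOutputType pass), then a model-local function, and only then UnknownOpError. A compressed sketch of that dispatch with stand-in types and a hypothetical string key in place of GetModelLocalFunctionsMapIdentifier:

```cpp
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Node {
  std::string domain;
  std::string op_type;
};

using Registry = std::unordered_map<std::string, std::function<void()>>;

void ProcessSketch(
    const Node& n,
    const std::unordered_map<std::string, int>& opset_imports,  // domain -> version
    const Registry& schemas,                // registered OpSchema inference
    const Registry& model_local_functions) {
  if (opset_imports.find(n.domain) == opset_imports.end())
    return;  // the diff elides how an unimported domain is reported
  const std::string key = n.domain + "/" + n.op_type;
  if (auto s = schemas.find(key); s != schemas.end()) {
    s->second();  // schema inference, or ProcessCall for function ops
  } else if (auto f = model_local_functions.find(key); f != model_local_functions.end()) {
    f->second();  // inline inference over the model-local FunctionProto
  } else {
    throw std::invalid_argument("Unknown operator/function: '" + n.op_type + "'");
  }
}
```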
@@ -483,28 +486,15 @@
} else {
UnknownOpError(n);
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
ONNX_HANDLE_EXCEPTION([&]() { inference_errors.push_back(GetErrorWithNodeInfo(n, ex)); });
// Continue with inference for remaining nodes
return;
}

ONNX_TRY {
// check the type-equality for input and output
if (options.check_type && schema) {
schema->CheckInputOutputType(ctx);
}

for (int i = 0; i < n.output_size(); ++i) {
// skip type and shape propagation for missing optional outputs.
if (!n.output(i).empty())
UpdateType(n.output(i), ctx.getOutputType(i));
}

// Constant values are tracked to improve inference/checking for subsequent nodes.
ProcessConstant(n);

// If data propagation is enabled, propagate shape data if it exists.
// If data-propagation is enabled, partial-evaluation (aka data-propagation) is performed
// to improve inference/checking for subsequent nodes.
if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
if (generated_shape_data_by_name == nullptr) {
fail_shape_inference(
@@ -515,7 +505,11 @@
schema->GetDataPropagationFunction()(data_propagation_ctx);
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
ONNX_HANDLE_EXCEPTION([&]() { inference_errors.push_back(GetErrorWithNodeInfo(n, ex)); });
}
ONNX_CATCH(const std::runtime_error& err) {
// TODO: Fix this. Unclear if this should be remapped to a shape inference error.

[cpplint, line 512] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
}
}
@@ -544,7 +538,7 @@
}
}

void Process(GraphProto& graph) {

[cpplint, line 541] Is this a non-const reference? If so, make const or use a pointer: GraphProto& graph [runtime/references] [2]
if (symbol_table) {
TraverseGraphsToAddExistingSymbols(graph, *symbol_table);
}
@@ -584,7 +578,7 @@
}
}

void Process(const NodeProto& n, internal::AttributeBinder& attribute_binder) {

[cpplint, line 581] Is this a non-const reference? If so, make const or use a pointer: internal::AttributeBinder& attribute_binder [runtime/references] [2]
NodeProto copy_n(n);
attribute_binder.VisitNode(&copy_n);
Process(copy_n);
@@ -695,7 +689,7 @@
void FinalizeShapeInference() {
auto& errors = getErrors();
// Throw shape inference error if any. Error mode right now only supports 0 and 1.
// When set to 0, any node level shape inference errors are not thrown. This is to support backward compatibility

[misspell, line 692] "compatiblity" is a misspelling of "compatibility"
// with 1.7 and earlier releases. When set to 1 it will throw all exceptions.
// TODO: Add a more granular way for exception handling.
if (!errors.empty() && options.error_mode > 0) {
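The comment above pins down the error_mode contract: node-level InferenceErrors are accumulated during the pass (the ONNX_CATCH blocks earlier), and FinalizeShapeInference either swallows them (mode 0, matching the behavior of releases 1.7 and earlier) or throws them all (mode 1). A minimal sketch of that finalization, assuming a plain string error list:

```cpp
#include <stdexcept>
#include <string>
#include <vector>

void FinalizeShapeInferenceSketch(const std::vector<std::string>& errors, int error_mode) {
  if (!errors.empty() && error_mode > 0) {
    std::string combined;
    for (const auto& e : errors) {
      combined += e;
      combined += '\n';
    }
    // The real code raises an ONNX InferenceError; std::runtime_error stands in.
    throw std::runtime_error(combined);
  }
  // error_mode == 0: errors are dropped and inference results stay partial,
  // preserving the lenient pre-strict-mode behavior.
}
```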