Drop "one of" default attribute check in LabelEncoder (#5673)
### Description
Drop "one of" checking on the default value in the `LabelEncoder`.

### Motivation and Context
When implementing `LabelEncoder` in onnxruntime (for the upcoming
ai.onnx.ml opset 4), it became clear that the type inference implemented
for `LabelEncoder` would not work well in practice: the type inference
pass requires that only one of the `default_*` attributes is set. Since
some of these attributes are optional and carry schema-level default
values, more than one of them can effectively appear to be set even for a
valid model, so the check is not workable. This PR addresses this by
dropping the check. Note that this was [not
checked](https://github.com/onnx/onnx/blob/f556cffc40dc769a2c8e3bb5fd9a90f1e7fbb7eb/onnx/defs/traditionalml/old.cc#L277-L338)
in previous opsets anyway.
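
To make the practical effect concrete, below is a minimal, hypothetical sketch (not part of this change, and written against the Python API rather than the C++ schema code touched here) of an opset-4 `LabelEncoder` that sets `default_tensor` explicitly while the other `default_*` attributes keep their schema defaults. The attribute names come from the schema in this diff; the input/output types, shapes, and opset versions are illustrative assumptions. With the "one of" check gone, shape inference only validates the explicitly supplied `default_tensor` (element type matching the values, singleton 1-D shape).

```python
import onnx
from onnx import TensorProto, helper

# A LabelEncoder mapping strings to int64s, with an explicit default_tensor.
node = helper.make_node(
    "LabelEncoder",
    inputs=["x"],
    outputs=["y"],
    domain="ai.onnx.ml",
    keys_strings=["a", "b", "c"],
    values_int64s=[0, 1, 2],
    # Singleton 1-D tensor whose element type matches the values (INT64),
    # which is what the retained default_tensor check requires.
    default_tensor=helper.make_tensor("default", TensorProto.INT64, dims=[1], vals=[-1]),
)

graph = helper.make_graph(
    [node],
    "label_encoder_default_tensor",
    inputs=[helper.make_tensor_value_info("x", TensorProto.STRING, [None])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.INT64, [None])],
)

model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("ai.onnx.ml", 4), helper.make_opsetid("", 19)],
)

# strict_mode surfaces shape-inference failures as exceptions; with the
# "one of" check dropped, this model is accepted even though the other
# default_* attributes are optional with schema defaults.
inferred = onnx.shape_inference.infer_shapes(model, strict_mode=True)
print(inferred.graph.output[0])
```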

---------

Signed-off-by: Aditya Goel <agoel4512@gmail.com>
adityagoel4512 committed Oct 16, 2023
1 parent a93ec80 commit 5f908a9
1 changed file: onnx/defs/traditionalml/defs.cc (16 additions, 22 deletions)
@@ -382,28 +382,6 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
             fail_shape_inference(
                 "At least one of values_tensor, values_strings, values_int64s, values_floats must be set.");
           }
-
-          int default_length, default_type;
-          std::tie(default_type, default_length) = getAttributeElementTypeAndLength(
-              ctx, {"default_tensor", "default_string", "default_int64", "default_float"});
-          if (default_type != TensorProto::UNDEFINED) {
-            if (value_type != default_type) {
-              fail_shape_inference(
-                  "The value type ",
-                  value_type,
-                  " and the default type ",
-                  default_type,
-                  " are different, which is not permitted for LabelEncoders.");
-            }
-
-            // Ensure default_tensor is a singleton if set
-            const AttributeProto* default_tensor = ctx.getAttribute("default_tensor");
-            if (default_tensor != nullptr &&
-                (default_tensor->t().dims_size() != 1 || default_tensor->t().dims(0) != 1)) {
-              fail_shape_inference("default_tensor must be a singleton if set.");
-            }
-          }
-
           if (value_length != key_length) {
             fail_shape_inference(
                 "The number of keys ",
@@ -413,6 +391,22 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
                 " must be the same in the LabelEncoder.");
           }
 
+          auto default_attr = ctx.getAttribute("default_tensor");
+          if (nullptr != default_attr && default_attr->has_t() && default_attr->t().has_data_type() &&
+              default_attr->t().data_type() != TensorProto_DataType_UNDEFINED) {
+            auto default_tensor = default_attr->t();
+            if (default_tensor.data_type() != value_type) {
+              fail_shape_inference(
+                  "The default tensor type ",
+                  default_tensor.data_type(),
+                  " and the value type ",
+                  value_type,
+                  " must be the same in the LabelEncoder.");
+            }
+            if (1 != default_tensor.dims_size() || 1 != default_tensor.dims(0)) {
+              fail_shape_inference("The default tensor must be a singleton 1D tensor.");
+            }
+          }
           // Propagate shape from input type and assign output type based on value type
           ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(value_type);
           propagateShapeFromInputToOutput(ctx, 0, 0);
