Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parser/printer support external data format #5688

Merged
8 changes: 7 additions & 1 deletion docs/Syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ The grammar below describes the syntax:
value-info ::= type id
value-infos ::= value-info (',' value-info)*
value-info-list ::= '(' value-infos? ')'
value-info-or-initializer ::= type id [ = '{' prim-constants '}']
quoted-str :== '"' ([^"])* '"'
str-str :== quoted-str ':' quoted-str
str-str-list :== '[' str-str (',' str-str)* ']'
yocox marked this conversation as resolved.
Show resolved Hide resolved
internal-data ::= '{' prim-constants '}'
external-data ::= str-str-list
constant-data ::= internal-data | external-data
value-info-or-initializer ::= type id [ '=' constant-data ]
value-info-or-initializers ::= value-info-or-initializer (',' value-info-or-initializer)*
input-list ::= '(' value-info-or-initializers? ')'
output-list ::= '(' value-infos? ')'
Expand Down
2 changes: 1 addition & 1 deletion docs/docsgen/source/intro/python.md
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ from onnx.checker import check_model
input = '''
<
ir_version: 8,
opset_import: [ '' : 15]
opset_import: [ "" : 15]
>
agraph (float[I,J] X, float[I] A, float[I] B) => (float[I] Y) {
XA = MatMul(X, A)
Expand Down
108 changes: 60 additions & 48 deletions onnx/defs/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,19 @@
return Status::OK();
}

Status OnnxParser::Parse(StringStringList& stringStringList) {
std::string strval;
do {
auto* metadata = stringStringList.Add();
PARSE_TOKEN(strval);
metadata->set_key(strval);
MATCH(':');
PARSE_TOKEN(strval);
metadata->set_value(strval);
} while (Matches(','));
return Status::OK();
}

Status OnnxParser::Parse(TensorProto& tensorProto) {
tensorProto = TensorProto();
// Parse the concrete tensor-type with numeric dimensions:
Expand Down Expand Up @@ -410,46 +423,52 @@
float floatval = 0.0;
double dblval = 0.0;
std::string strval;
MATCH('{');
if (!Matches('}')) {
do {
switch (static_cast<TensorProto::DataType>(elem_type)) {
case TensorProto::DataType::TensorProto_DataType_INT8:
case TensorProto::DataType::TensorProto_DataType_INT16:
case TensorProto::DataType::TensorProto_DataType_INT32:
case TensorProto::DataType::TensorProto_DataType_UINT8:
case TensorProto::DataType::TensorProto_DataType_UINT16:
case TensorProto::DataType::TensorProto_DataType_BOOL:
PARSE_TOKEN(intval);
// TODO: check values are in the correct range.
tensorProto.add_int32_data(intval);
break;
case TensorProto::DataType::TensorProto_DataType_INT64:
PARSE_TOKEN(intval);
tensorProto.add_int64_data(intval);
break;
case TensorProto::DataType::TensorProto_DataType_UINT32:
case TensorProto::DataType::TensorProto_DataType_UINT64:
PARSE_TOKEN(uintval);
tensorProto.add_uint64_data(uintval);
break;
case TensorProto::DataType::TensorProto_DataType_FLOAT:
PARSE_TOKEN(floatval);
tensorProto.add_float_data(floatval);
break;
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
PARSE_TOKEN(dblval);
tensorProto.add_double_data(dblval);
break;
case TensorProto::DataType::TensorProto_DataType_STRING:
PARSE_TOKEN(strval);
tensorProto.add_string_data(strval);
break;
default:
return ParseError("Unhandled type: %d", elem_type);
}
} while (Matches(','));
MATCH('}');
if (Matches('{')) {
if (!Matches('}')) {
do {
switch (static_cast<TensorProto::DataType>(elem_type)) {
case TensorProto::DataType::TensorProto_DataType_INT8:
case TensorProto::DataType::TensorProto_DataType_INT16:
case TensorProto::DataType::TensorProto_DataType_INT32:
case TensorProto::DataType::TensorProto_DataType_UINT8:
case TensorProto::DataType::TensorProto_DataType_UINT16:
case TensorProto::DataType::TensorProto_DataType_BOOL:
PARSE_TOKEN(intval);
// TODO: check values are in the correct range.

Check warning on line 437 in onnx/defs/parser.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnx/defs/parser.cc:437: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
tensorProto.add_int32_data(intval);
break;
case TensorProto::DataType::TensorProto_DataType_INT64:
PARSE_TOKEN(intval);
tensorProto.add_int64_data(intval);
break;
case TensorProto::DataType::TensorProto_DataType_UINT32:
case TensorProto::DataType::TensorProto_DataType_UINT64:
PARSE_TOKEN(uintval);
tensorProto.add_uint64_data(uintval);
break;
case TensorProto::DataType::TensorProto_DataType_FLOAT:
PARSE_TOKEN(floatval);
tensorProto.add_float_data(floatval);
break;
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
PARSE_TOKEN(dblval);
tensorProto.add_double_data(dblval);
break;
case TensorProto::DataType::TensorProto_DataType_STRING:
PARSE_TOKEN(strval);
tensorProto.add_string_data(strval);
break;
default:
return ParseError("Unhandled type: %d", elem_type);
}
} while (Matches(','));
MATCH('}');
}
} else if (Matches('[')) {
tensorProto.set_data_location(TensorProto::DataLocation::TensorProto_DataLocation_EXTERNAL);
auto& externalData = *tensorProto.mutable_external_data();
PARSE(externalData);
MATCH(']');
}
return Status::OK();
}
Expand Down Expand Up @@ -798,14 +817,7 @@
auto& metadata_props = *model.mutable_metadata_props();
MATCH('[');
if (!Matches(']')) {
do {
auto* metadata = metadata_props.Add();
PARSE_TOKEN(strval);
metadata->set_key(strval);
MATCH(':');
PARSE_TOKEN(strval);
metadata->set_value(strval);
} while (Matches(','));
PARSE(metadata_props);
MATCH(']');
}
break;
Expand Down
4 changes: 4 additions & 0 deletions onnx/defs/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ using TensorList = google::protobuf::RepeatedPtrField<TensorProto>;

using OpsetIdList = google::protobuf::RepeatedPtrField<OperatorSetIdProto>;

using StringStringList = google::protobuf::RepeatedPtrField<StringStringEntryProto>;

#define CHECK_PARSER_STATUS(status) \
{ \
auto local_status_ = status; \
Expand Down Expand Up @@ -388,6 +390,8 @@ class OnnxParser : public ParserBase {

Status Parse(TypeProto& typeProto);

Status Parse(StringStringList& stringStringList);

Status Parse(TensorProto& tensorProto);

Status Parse(AttributeProto& attr);
Expand Down
15 changes: 8 additions & 7 deletions onnx/defs/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

namespace ONNX_NAMESPACE {

using MetaDataProp = StringStringEntryProto;
using MetaDataProps = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
using StringStringEntryProtos = google::protobuf::RepeatedPtrField<StringStringEntryProto>;

class ProtoPrinter {
public:
Expand Down Expand Up @@ -58,11 +57,11 @@

void print(const OpsetIdList& opsets);

void print(const MetaDataProps& metadataprops) {
printSet("[", ", ", "]", metadataprops);
void print(const StringStringEntryProtos& stringStringProtos) {
printSet("[", ", ", "]", stringStringProtos);
}

void print(const MetaDataProp& metadata) {
void print(const StringStringEntryProto& metadata) {
printQuoted(metadata.key());
output_ << ": ";
printQuoted(metadata.value());
Expand Down Expand Up @@ -197,8 +196,10 @@
if (is_initializer) {
output_ << " = ";
}
// TODO: does not yet handle all types or externally stored data.
if (tensor.has_raw_data()) {
// TODO: does not yet handle all types

Check warning on line 199 in onnx/defs/printer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnx/defs/printer.cc:199: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
if (tensor.has_data_location() && tensor.data_location() == TensorProto_DataLocation_EXTERNAL) {
print(tensor.external_data());
} else if (tensor.has_raw_data()) {
switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
case TensorProto::DataType::TensorProto_DataType_INT32:
printSet(" {", ",", "}", ParseData<int32_t>(&tensor));
Expand Down
26 changes: 26 additions & 0 deletions onnx/test/cpp/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,5 +557,31 @@
CheckModel(code);
}

TEST(ParserTest, ExternalDataTest) {
const char* code = R"ONNX(
agraph (float y = {1.0}, float[N] z) => (w) <
float[3, 2] m1 = ["location": "weight_1.bin", "offset": "17"],
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
float[2, 1] m2 = {1.0, 2.0}
>
{
x = Add(y, z)
m = Mul(m1, m1)
}
)ONNX";

GraphProto graph;
Parse(graph, code);

EXPECT_EQ(graph.input_size(), 2);
EXPECT_EQ(graph.output_size(), 1);
EXPECT_EQ(graph.initializer_size(), 3); // m1, m2

Check warning on line 577 in onnx/test/cpp/parser_test.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 At least two spaces is best between code and comments [whitespace/comments] [2] Raw Output: onnx/test/cpp/parser_test.cc:577: At least two spaces is best between code and comments [whitespace/comments] [2]
EXPECT_EQ(graph.value_info_size(), 0); // x

Check warning on line 578 in onnx/test/cpp/parser_test.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 At least two spaces is best between code and comments [whitespace/comments] [2] Raw Output: onnx/test/cpp/parser_test.cc:578: At least two spaces is best between code and comments [whitespace/comments] [2]
EXPECT_EQ(graph.initializer().Get(1).data_location(), TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(0).key(), "location");
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(0).value(), "weight_1.bin");
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(1).key(), "offset");
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(1).value(), "17");
}

} // namespace Test
} // namespace ONNX_NAMESPACE