-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement version converter to complete DFT-20. Requires #5514 Reference: #5514 (comment) --------- Signed-off-by: Justin Chu <justinchu@microsoft.com> Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Ganesan Ramalingam <grama@microsoft.com>
- Loading branch information
1 parent
b5111b8
commit e11dacf
Showing
6 changed files
with
265 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// Copyright (c) ONNX Project Contributors | ||
|
||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "onnx/version_converter/adapters/adapter.h" | ||
|
||
namespace ONNX_NAMESPACE { | ||
namespace version_conversion { | ||
|
||
class AxisAttributeToInput : public Adapter { | ||
public: | ||
AxisAttributeToInput( | ||
const std::string& op_name, | ||
const OpSetID& initial, | ||
const OpSetID& target, | ||
size_t axis_index, | ||
int64_t default_axis) | ||
: Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {} | ||
|
||
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override { | ||
if (node->hasAttribute(kaxis)) { | ||
AttrToInput(graph, node, node->i(kaxis), this->axis_index); | ||
node->removeAttribute(kaxis); | ||
return node; | ||
} | ||
|
||
// Fill in the default value for axis | ||
AttrToInput(graph, node, default_axis, this->axis_index); | ||
return node; | ||
} | ||
|
||
private: | ||
size_t axis_index; | ||
int64_t default_axis; | ||
|
||
void AttrToInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis, size_t axis_index) const { | ||
const ArrayRef<Value*>& inputs = node->inputs(); | ||
|
||
// Add the optional inputs if they don't exist | ||
for (size_t i = inputs.size(); i < axis_index; ++i) { | ||
Node* empty_input = graph->create(kUndefined); | ||
empty_input->insertBefore(node); | ||
node->addInput(empty_input->output()); | ||
} | ||
|
||
// Add the axis input | ||
Node* constant = CreateAxisInput(graph, node, axis); | ||
node->addInput(constant->output()); | ||
} | ||
|
||
Node* CreateAxisInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis) const { | ||
Tensor t; | ||
t.elem_type() = TensorProto_DataType_INT64; | ||
t.sizes() = std::vector<int64_t>{}; | ||
auto& data = t.int64s(); | ||
data.emplace_back(axis); | ||
|
||
Node* constant = graph->create(kConstant); | ||
constant->insertBefore(node); | ||
constant->t_(kvalue, t); | ||
return constant; | ||
} | ||
}; | ||
|
||
} // namespace version_conversion | ||
} // namespace ONNX_NAMESPACE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// Copyright (c) ONNX Project Contributors | ||
|
||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "onnx/version_converter/adapters/adapter.h" | ||
|
||
namespace ONNX_NAMESPACE { | ||
namespace version_conversion { | ||
class AxisInputToAttribute : public Adapter { | ||
public: | ||
explicit AxisInputToAttribute( | ||
const std::string& op_name, | ||
const OpSetID& initial, | ||
const OpSetID& target, | ||
size_t axis_index, | ||
int64_t default_axis) | ||
: Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {} | ||
|
||
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override { | ||
if (!HasAxisInput(node)) { | ||
node->i_(kaxis, this->default_axis); | ||
return EnsureAndReturnNode(node); | ||
} | ||
|
||
const ArrayRef<Value*>& inputs = node->inputs(); | ||
Value* axis_val = inputs[this->axis_index]; | ||
Node* axis_node = axis_val->node(); | ||
|
||
if (axis_node->kind() == kConstant) { | ||
HandleConstantNode(node, axis_node, axis_val); | ||
return EnsureAndReturnNode(node); | ||
} | ||
|
||
if (graph->is_constant_initializer(axis_val)) { | ||
HandleInitializerNode(graph, node, axis_val); | ||
return EnsureAndReturnNode(node); | ||
} | ||
|
||
ONNX_ASSERTM(false, "Axis input must be a constant or initializer for promotion to attribute."); | ||
} | ||
|
||
private: | ||
size_t axis_index; | ||
int64_t default_axis; | ||
|
||
bool HasAxisInput(const Node* node) const { | ||
const ArrayRef<const Value*>& inputs = node->inputs(); | ||
return inputs.size() > this->axis_index && inputs[this->axis_index]->node()->kind() != kUndefined; | ||
} | ||
|
||
void HandleConstantNode(Node* node, Node* axis_node, Value* axis_val) const { | ||
const std::vector<int64_t>& int64s = axis_node->t(kvalue).int64s(); | ||
if (int64s.empty()) { | ||
std::string raw_data = axis_node->t(kvalue).raw(); | ||
ONNX_ASSERTM( | ||
raw_data.size() != 0 && raw_data.size() % 8 == 0, | ||
"Raw Data must be non-empty and size must be a multiple of 8"); | ||
const int64_t* raw = reinterpret_cast<const int64_t*>(raw_data.c_str()); | ||
node->i_(kaxis, raw[0]); | ||
} else { | ||
node->i_(kaxis, int64s.at(0)); | ||
} | ||
node->removeInput(this->axis_index); | ||
if (axis_val->uses().size() < 1) { | ||
axis_node->destroy(); | ||
} | ||
} | ||
|
||
void HandleInitializerNode(std::shared_ptr<Graph> graph, Node* node, Value* axis_val) const { | ||
const std::string initializer_name = axis_val->uniqueName(); | ||
for (const auto& initializer : graph->initializers()) { | ||
if (initializer.name() == initializer_name) { | ||
node->i_(kaxis, initializer.int64s().at(0)); | ||
node->removeInput(this->axis_index); | ||
// Remove initializer | ||
if (axis_val->uses().size() < 1) | ||
graph->eraseInitializer(initializer_name); | ||
break; | ||
} | ||
} | ||
} | ||
|
||
inline Node* EnsureAndReturnNode(Node* node) const { | ||
ONNX_ASSERTM(node->hasAttribute(kaxis), "Axis attribute not created. This may be a bug."); | ||
return node; | ||
} | ||
}; | ||
|
||
} // namespace version_conversion | ||
} // namespace ONNX_NAMESPACE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters