Skip to content

Commit

Permalink
Copilot and fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Sep 22, 2023
1 parent 408d8c2 commit f778bb6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
17 changes: 11 additions & 6 deletions onnx/version_converter/adapters/axis_attribute_to_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ namespace version_conversion {

class AxisAttributeToInput : public Adapter {
public:
explicit AxisAttributeToInput(
AxisAttributeToInput(
const std::string& op_name,
const OpSetID& initial,
const OpSetID& target,
int64_t axis_index,
size_t axis_index,
int64_t default_axis)
: Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}

Expand All @@ -38,20 +38,25 @@ class AxisAttributeToInput : public Adapter {
}

private:
int64_t axis_index;
size_t axis_index;
int64_t default_axis;

void AttrToInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis, int64_t axis_index) const {
void AttrToInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis, size_t axis_index) const {
const ArrayRef<Value*>& inputs = node->inputs();

// Add the optional inputs if they don't exist
while (inputs.size() < axis_index) {
for (size_t i = inputs.size(); i < axis_index; ++i) {
Node* empty_input = graph->create(kUndefined);
empty_input->insertBefore(node);
node->addInput(empty_input->output());
}

// Add the axis input
Node* constant = CreateAxisInput(graph, node, axis);
node->addInput(constant->output());
}

Node* CreateAxisInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis) const {
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
t.sizes() = std::vector<int64_t>{};
Expand All @@ -61,7 +66,7 @@ class AxisAttributeToInput : public Adapter {
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
return constant;
}
};

Expand Down
4 changes: 2 additions & 2 deletions onnx/version_converter/adapters/axis_input_to_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class AxisInputToAttribute : public Adapter {
const std::string& op_name,
const OpSetID& initial,
const OpSetID& target,
int64_t axis_index,
size_t axis_index,
int64_t default_axis)
: Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}

Expand All @@ -45,7 +45,7 @@ class AxisInputToAttribute : public Adapter {
}

private:
int64_t axis_index;
size_t axis_index;
int64_t default_axis;

bool HasAxisInput(const Node* node) const {
Expand Down

0 comments on commit f778bb6

Please sign in to comment.