TreeEnsemble base values for the reference implementation #5665

Merged (10 commits) on Oct 18, 2023
4 changes: 2 additions & 2 deletions docs/Changelog-ml.md
@@ -1085,9 +1085,9 @@ This version of the operator has been available since version 3 of the 'ai.onnx.
<dt><tt>aggregate_function</tt> : string (default is SUM)</dt>
<dd>Defines how to aggregate leaf values within a target. <br>One of 'AVERAGE,' 'SUM,' 'MIN,' 'MAX.'</dd>
<dt><tt>base_values</tt> : list of floats</dt>
- <dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
+ <dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>base_values_as_tensor</tt> : tensor</dt>
- <dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
+ <dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>n_targets</tt> : int</dt>
<dd>The total number of targets.</dd>
<dt><tt>nodes_falsenodeids</tt> : list of ints</dt>
4 changes: 2 additions & 2 deletions docs/Operators-ml.md
@@ -1038,9 +1038,9 @@ Other versions of this operator: <a href="Changelog-ml.md#ai.onnx.ml.TreeEnsembl
<dt><tt>aggregate_function</tt> : string (default is SUM)</dt>
<dd>Defines how to aggregate leaf values within a target. <br>One of 'AVERAGE,' 'SUM,' 'MIN,' 'MAX.'</dd>
<dt><tt>base_values</tt> : list of floats</dt>
- <dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
+ <dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>base_values_as_tensor</tt> : tensor</dt>
- <dd>Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
+ <dd>Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)</dd>
<dt><tt>n_targets</tt> : int</dt>
<dd>The total number of targets.</dd>
<dt><tt>nodes_falsenodeids</tt> : list of ints</dt>
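The reworded description drops wording that was apparently copied from the classifier ("classification", "class score") and makes the order of operations explicit: leaf values are first combined per target with aggregate_function, and the base value is then added to that aggregated result. A minimal NumPy sketch of the convention (not from the PR), using made-up leaf values and a hypothetical base value of 1.0:

```python
import numpy as np

# Hypothetical per-tree scores for a single input row and a single target.
tree_scores = np.array([0.2, 0.3, 0.5], dtype=np.float32)
base_value = 1.0  # hypothetical base_values[0]

aggregated = {
    "SUM": tree_scores.sum(),
    "AVERAGE": tree_scores.mean(),
    "MIN": tree_scores.min(),
    "MAX": tree_scores.max(),
}
# The base value is added after the aggregate function has been applied,
# whichever aggregate_function is chosen; it is not mixed into the per-tree values.
predictions = {agg: float(value + base_value) for agg, value in aggregated.items()}
print(predictions)  # roughly {'SUM': 2.0, 'AVERAGE': 1.33, 'MIN': 1.2, 'MAX': 1.5}
```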
4 changes: 2 additions & 2 deletions onnx/defs/traditionalml/defs.cc
@@ -1004,12 +1004,12 @@
std::string("SUM"))
.Attr(
"base_values",
"Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
"Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)",

AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"base_values_as_tensor",
"Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
"Base values for regression, added to final prediction after applying aggregate_function; the size must be the same as the classes or can be left unassigned (assumed 0)",

AttributeProto::TENSOR,
OPTIONAL_VALUE)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
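For context on the two attributes whose descriptions change here: base_values is declared as AttributeProto::FLOATS, while base_values_as_tensor is AttributeProto::TENSOR, so only the tensor form can also carry double precision values. A small sketch (not part of this PR) of how the two forms can be built with onnx.helper:

```python
from onnx import AttributeProto, TensorProto, helper

# base_values: a plain list of floats becomes a FLOATS attribute.
attr_floats = helper.make_attribute("base_values", [0.5])
assert attr_floats.type == AttributeProto.FLOATS

# base_values_as_tensor: the values are wrapped in a TensorProto,
# which also allows element types such as double.
attr_tensor = helper.make_attribute(
    "base_values_as_tensor",
    helper.make_tensor("base_values_as_tensor", TensorProto.DOUBLE, dims=[1], vals=[0.5]),
)
assert attr_tensor.type == AttributeProto.TENSOR
```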
4 changes: 3 additions & 1 deletion onnx/reference/ops/aionnxml/op_tree_ensemble_regressor.py
@@ -97,8 +97,10 @@ def _run( # type: ignore
)
if aggregate_function == "AVERAGE":
res /= n_trees

+ # Convention is to add base_values after aggregate function
if base_values is not None:
- res[:, :] = np.array(base_values).reshape((1, -1))
+ res[:, :] += np.array(base_values).reshape((1, -1))

if post_transform in (None, "NONE"):
return (res,)
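The one-character change from = to += is the functional fix: previously, when base_values was set, the reference implementation overwrote the aggregated tree scores with the base values instead of offsetting them. A standalone illustration with made-up numbers:

```python
import numpy as np

# Aggregated tree scores for 3 rows and 1 target (made-up values).
res = np.array([[0.576923], [0.576923], [0.576923]], dtype=np.float32)
base_values = [1.0]

old = res.copy()
old[:, :] = np.array(base_values).reshape((1, -1))   # previous behaviour: scores are discarded
new = res.copy()
new[:, :] += np.array(base_values).reshape((1, -1))  # fixed behaviour: base value offsets the scores

print(old.ravel())  # [1. 1. 1.]
print(new.ravel())  # [1.576923 1.576923 1.576923]
```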
33 changes: 23 additions & 10 deletions onnx/test/reference_evaluator_ml_test.py
@@ -10,6 +10,7 @@

import numpy as np # type: ignore
from numpy.testing import assert_allclose # type: ignore
+ from parameterized import parameterized

from onnx import ONNX_ML, TensorProto, TypeProto, ValueInfoProto
from onnx.checker import check_model
@@ -757,7 +758,7 @@ def test_linear_classifier_unary(self):

@staticmethod
def _get_test_tree_ensemble_regressor(
- aggregate_function, rule="BRANCH_LEQ", unique_targets=False
+ aggregate_function, rule="BRANCH_LEQ", unique_targets=False, base_values=None
):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])
@@ -786,6 +787,7 @@ def _get_test_tree_ensemble_regressor(
domain="ai.onnx.ml",
n_targets=1,
aggregate_function=aggregate_function,
+ base_values=base_values,
nodes_falsenodeids=[4, 3, 0, 0, 0, 2, 0, 4, 0, 0],
nodes_featureids=[0, 2, 0, 0, 0, 0, 0, 2, 0, 0],
nodes_hitrates=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
@@ -828,23 +830,34 @@ def _get_test_tree_ensemble_regressor(
check_model(onx)
return onx

+ @parameterized.expand(
+ [
+ (f"{agg}_{base_values}", base_values, agg)
+ for base_values in (None, [1.0])
+ for agg in ("SUM", "AVERAGE", "MIN", "MAX")
+ ]
+ )
@unittest.skipIf(not ONNX_ML, reason="onnx not compiled with ai.onnx.ml")
- def test_tree_ensemble_regressor(self):
+ def test_tree_ensemble_regressor(self, name, base_values, agg):
self.assertTrue(ONNX_ML)
+ del name # variable only used to print test name
x = np.arange(9).reshape((-1, 3)).astype(np.float32) / 10 - 0.5
expected_agg = {
"SUM": np.array([[0.576923], [0.576923], [0.576923]], dtype=np.float32),
"AVERAGE": np.array([[0.288462], [0.288462], [0.288462]], dtype=np.float32),
"MIN": np.array([[0.076923], [0.076923], [0.076923]], dtype=np.float32),
"MAX": np.array([[0.5], [0.5], [0.5]], dtype=np.float32),
}
- for agg in ("SUM", "AVERAGE", "MIN", "MAX"):
- expected = expected_agg[agg]
- with self.subTest(aggregate_function=agg):
- onx = self._get_test_tree_ensemble_regressor(agg)
- self._check_ort(onx, {"X": x}, equal=True)
- sess = ReferenceEvaluator(onx)
- got = sess.run(None, {"X": x})
- assert_allclose(expected, got[0], atol=1e-6)

+ expected = expected_agg[agg]
+ if base_values is not None:
+ expected += base_values[0]
+ with self.subTest(aggregate_function=agg):
+ onx = self._get_test_tree_ensemble_regressor(agg, base_values=base_values)
+ self._check_ort(onx, {"X": x}, equal=True)
+ sess = ReferenceEvaluator(onx)
+ got = sess.run(None, {"X": x})
+ assert_allclose(expected, got[0], atol=1e-6)

@unittest.skipIf(not ONNX_ML, reason="onnx not compiled with ai.onnx.ml")
def test_tree_ensemble_regressor_rule(self):