Skip to content

Commit 7398767

Browse files
l46kokcopybara-github
authored andcommittedAug 30, 2024·
Provide an overload to accept a depth level in flatten function
PiperOrigin-RevId: 669346696
1 parent 6c71f16 commit 7398767

File tree

3 files changed

+66
-14
lines changed

3 files changed

+66
-14
lines changed
 

‎extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java

+18-10
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import dev.cel.common.CelFunctionDecl;
2222
import dev.cel.common.CelOverloadDecl;
2323
import dev.cel.common.types.ListType;
24+
import dev.cel.common.types.SimpleType;
2425
import dev.cel.common.types.TypeParamType;
2526
import dev.cel.compiler.CelCompilerLibrary;
2627
import dev.cel.runtime.CelRuntime;
@@ -44,11 +45,18 @@ public enum Function {
4445
"list_flatten",
4546
"Flattens a list by a single level",
4647
ListType.create(LIST_PARAM_TYPE),
47-
ListType.create(ListType.create(LIST_PARAM_TYPE)))),
48-
// TODO: add list_flatten_list_int
48+
ListType.create(ListType.create(LIST_PARAM_TYPE))),
49+
CelOverloadDecl.newMemberOverload(
50+
"list_flatten_list_int",
51+
"Flattens a list to the specified level. A negative depth value flattens the list"
52+
+ " recursively to its deepest level.",
53+
ListType.create(SimpleType.DYN),
54+
ListType.create(SimpleType.DYN),
55+
SimpleType.INT)),
56+
CelRuntime.CelFunctionBinding.from(
57+
"list_flatten", Collection.class, list -> flatten(list, 1)),
4958
CelRuntime.CelFunctionBinding.from(
50-
"list_flatten", Collection.class, list -> flatten(list, 1))),
51-
;
59+
"list_flatten_list_int", Collection.class, Long.class, CelListsExtensions::flatten));
5260

5361
private final CelFunctionDecl functionDecl;
5462
private final ImmutableSet<CelFunctionBinding> functionBindings;
@@ -84,15 +92,15 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
8492
}
8593

8694
@SuppressWarnings("unchecked")
87-
private static ImmutableList<Object> flatten(Collection<Object> list, int level) {
88-
Preconditions.checkArgument(level == 1, "recursive flatten is not supported yet.");
95+
private static ImmutableList<Object> flatten(Collection<Object> list, long depth) {
96+
Preconditions.checkArgument(depth >= 0, "Level must be non-negative");
8997
ImmutableList.Builder<Object> builder = ImmutableList.builder();
9098
for (Object element : list) {
91-
if (element instanceof Collection) {
92-
Collection<Object> listItem = (Collection<Object>) element;
93-
builder.addAll(listItem);
94-
} else {
99+
if (!(element instanceof Collection) || depth == 0) {
95100
builder.add(element);
101+
} else {
102+
Collection<Object> listItem = (Collection<Object>) element;
103+
builder.addAll(flatten(listItem, depth - 1));
96104
}
97105
}
98106

‎extensions/src/main/java/dev/cel/extensions/README.md

+18-4
Original file line numberDiff line numberDiff line change
@@ -413,24 +413,38 @@ zero-based.
413413

414414
### Flatten
415415

416-
Flattens a list by one level. Support for flattening to a specified level
417-
will be provided in the future.
416+
Flattens a list by one level, or to the specified level. Providing a negative level will error.
418417

419418
Examples:
420419

421420
```
421+
// Single-level flatten:
422+
422423
[].flatten() // []
423424
[1,[2,3],[4]].flatten() // [1, 2, 3, 4]
424425
[1,[2,[3,4]]].flatten() // [1, 2, [3, 4]]
425426
[1,2,[],[],[3,4]].flatten() // [1, 2, 3, 4]
427+
428+
// Recursive flatten
429+
[1,[2,[3,[4]]]].flatten(2) // return [1, 2, 3, [4]]
430+
[1,[2,[3,[4]]]].flatten(3) // return [1, 2, 3, 4]
431+
432+
// Error
433+
[1,[2,[3,[4]]]].flatten(-1)
426434
```
427435

428436
Note that due to the current limitations of type-checker, a compilation error
429-
will occur if an already flat list is populated. For time being, you must wrap
430-
the list in dyn if you anticipate having to deal with a flat list:
437+
will occur if an already flat list is populated to the argument-less flatten
438+
function.
439+
440+
For time being, you must explicitly provide 1 as the depth level, or wrap the
441+
list in dyn if you anticipate having to deal with a flat list:
431442

432443
```
433444
[1,2,3].flatten() // error
445+
446+
// But the following will work:
447+
[1,2,3].flatten(1) // [1,2,3]
434448
dyn([1,2,3]).flatten() // [1,2,3]
435449
```
436450

‎extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java

+30
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import dev.cel.bundle.CelFactory;
2323
import dev.cel.common.CelValidationException;
2424
import dev.cel.parser.CelStandardMacro;
25+
import dev.cel.runtime.CelEvaluationException;
2526
import org.junit.Test;
2627
import org.junit.runner.RunWith;
2728

@@ -50,6 +51,35 @@ public void flattenSingleLevel_success(String expression) throws Exception {
5051
assertThat(result).isTrue();
5152
}
5253

54+
@Test
55+
@TestParameters("{expression: '[1,2,3,4].flatten(1) == [1,2,3,4]'}")
56+
@TestParameters("{expression: '[1,[2,[3,[4]]]].flatten(0) == [1,[2,[3,[4]]]]'}")
57+
@TestParameters("{expression: '[1,[2,[3,[4]]]].flatten(2) == [1,2,3,[4]]'}")
58+
@TestParameters("{expression: '[1,[2,[3,4]]].flatten(2) == [1,2,3,4]'}")
59+
@TestParameters("{expression: '[[], [[]], [[[]]]].flatten(2) == [[]]'}")
60+
@TestParameters("{expression: '[[], [[]], [[[]]]].flatten(3) == []'}")
61+
@TestParameters("{expression: '[[], [[]], [[[]]]].flatten(4) == []'}")
62+
// The overload with the depth accepts and returns a List(dyn), so the following is permitted.
63+
@TestParameters("{expression: '[1].flatten(1) == [1]'}")
64+
public void flatten_withDepthValue_success(String expression) throws Exception {
65+
boolean result = (boolean) CEL.createProgram(CEL.compile(expression).getAst()).eval();
66+
67+
assertThat(result).isTrue();
68+
}
69+
70+
@Test
71+
public void flatten_negativeDepth_throws() {
72+
CelEvaluationException e =
73+
assertThrows(
74+
CelEvaluationException.class,
75+
() -> CEL.createProgram(CEL.compile("[1,2,3,4].flatten(-1)").getAst()).eval());
76+
77+
assertThat(e)
78+
.hasMessageThat()
79+
.contains("evaluation error: Function 'list_flatten_list_int' failed");
80+
assertThat(e).hasCauseThat().hasMessageThat().isEqualTo("Level must be non-negative");
81+
}
82+
5383
@Test
5484
@TestParameters("{expression: '[1].flatten()'}")
5585
@TestParameters("{expression: '[{1: 2}].flatten()'}")

0 commit comments

Comments
 (0)
Please sign in to comment.