Skip to content

Commit

Permalink
Update inliner with latest changes from optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones committed Aug 30, 2023
2 parents 16c3bfc + 705546a commit 058f3b0
Show file tree
Hide file tree
Showing 6 changed files with 613 additions and 39 deletions.
127 changes: 115 additions & 12 deletions cel/folding.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,56 @@ import (
"github.com/google/cel-go/common/types/traits"
)

// ConstantFoldingOption defines a functional option for configuring constant folding.
type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error)

// MaxConstantFoldIterations limits the number of times literals may be folding during optimization.
//
// Defaults to 100 if not set.
func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
opt.maxFoldIterations = limit
return opt, nil
}
}

// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
// literal values within function calls and select statements with their evaluated result.
func NewConstantFoldingOptimizer() ASTOptimizer {
return &constantFoldingOptimizer{}
func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
folder := &constantFoldingOptimizer{
maxFoldIterations: defaultMaxConstantFoldIterations,
}
var err error
for _, o := range opts {
folder, err = o(folder)
if err != nil {
return nil, err
}
}
return folder, nil
}

type constantFoldingOptimizer struct{}
type constantFoldingOptimizer struct {
maxFoldIterations int
}

// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
// select statements and then evaluates them and replaces the call site with the literal result.
//
// Note: only values which can be represented as literals in CEL syntax are supported.
func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)

// Walk the list of foldable expression and continue to fold until there are no more folds left.
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
// a logic bug with the selection of expressions.
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
for len(foldableExprs) != 0 {
foldCount := 0
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
for _, fold := range foldableExprs {
// If the expression could be folded because it's a non-strict call, and the
// branches are pruned, continue to the next fold.
if fold.Kind() == ast.CallKind && maybePruneBranches(fold) {
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
continue
}
// Otherwise, assume all context is needed to evaluate the expression.
Expand All @@ -58,6 +84,7 @@ func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *as
return a
}
}
foldCount++
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
}
// Once all of the constants have been folded, try to run through the remaining comprehensions
Expand All @@ -72,7 +99,8 @@ func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *as
pruneOptionalElements(ctx, root)

// Ensure that all intermediate values in the folded expression can be represented as valid
// CEL literals within the AST structure.
// CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents`
// to avoid extra allocations during this final pass through the AST.
ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() != ast.LiteralKind {
return
Expand Down Expand Up @@ -117,28 +145,79 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
// but conditional will not be pruned cleanly, so this is one small area where the
// constant folding step reimplements a portion of the evaluator.
func maybePruneBranches(expr ast.NavigableExpr) bool {
func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
call := expr.AsCall()
args := call.Args()
switch call.FunctionName() {
case operators.LogicalAnd, operators.LogicalOr:
return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr)
case operators.Conditional:
args := call.Args()
cond := args[0]
truthy := args[1]
falsy := args[2]
if cond.Kind() != ast.LiteralKind {
return false
}
if cond.AsLiteral() == types.True {
expr.SetKindCase(truthy)
} else {
expr.SetKindCase(falsy)
}
return true
case operators.In:
haystack := args[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
expr.SetKindCase(ctx.NewLiteral(types.False))
return true
}
needle := args[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
expr.SetKindCase(ctx.NewLiteral(types.True))
return true
}
}
}
}
return false
}

func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool {
shortcircuit := types.False
skip := types.True
if function == operators.LogicalOr {
shortcircuit = types.True
skip = types.False
}
newArgs := []ast.Expr{}
for _, arg := range args {
if arg.Kind() != ast.LiteralKind {
newArgs = append(newArgs, arg)
continue
}
if arg.AsLiteral() == skip {
continue
}
if arg.AsLiteral() == shortcircuit {
expr.SetKindCase(arg)
return true
}
}
if len(newArgs) == 1 {
expr.SetKindCase(newArgs[0])
return true
}
expr.SetKindCase(ctx.NewCall(function, newArgs...))
return true
}

// pruneOptionalElements works from the bottom up to resolve optional elements within
// aggregate literals.
//
// Note, may aggregate literals will be resolved as arguments to functions or select
// Note, many aggregate literals will be resolved as arguments to functions or select
// statements, so this method exists to handle the case where the literal could not be
// fully resolved or exists outside of a call, select, or comprehension context.
func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) {
Expand Down Expand Up @@ -198,6 +277,8 @@ func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
entry := e.AsMapEntry()
key := entry.Key()
val := entry.Value()
// If the entry is not optional, or the value-side of the optional hasn't
// been resolved to a literal, then preserve the entry as-is.
if !entry.IsOptional() || val.Kind() != ast.LiteralKind {
updatedEntries = append(updatedEntries, e)
continue
Expand All @@ -207,6 +288,8 @@ func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
updatedEntries = append(updatedEntries, e)
continue
}
// When the key is not a literal, but the value is, then it needs to be
// restored to an optional value.
if key.Kind() != ast.LiteralKind {
undoOptVal, err := adaptLiteral(ctx, optElemVal)
if err != nil {
Expand Down Expand Up @@ -403,14 +486,14 @@ func constantCallMatcher(e ast.NavigableExpr) bool {
fnName := call.FunctionName()
if fnName == operators.LogicalAnd {
for _, child := range children {
if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.False {
if child.Kind() == ast.LiteralKind {
return true
}
}
}
if fnName == operators.LogicalOr {
for _, child := range children {
if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.True {
if child.Kind() == ast.LiteralKind {
return true
}
}
Expand All @@ -421,6 +504,22 @@ func constantCallMatcher(e ast.NavigableExpr) bool {
return true
}
}
if fnName == operators.In {
haystack := children[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
return true
}
needle := children[0]
if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
needleValue := needle.AsLiteral()
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
return true
}
}
}
}
// convert all other calls with constant arguments
for _, child := range children {
if !constantMatcher(child) {
Expand Down Expand Up @@ -448,3 +547,7 @@ func aggregateLiteralMatcher(e ast.NavigableExpr) bool {
var (
constantMatcher = ast.ConstantValueMatcher()
)

const (
defaultMaxConstantFoldIterations = 100
)

0 comments on commit 058f3b0

Please sign in to comment.