From 9ce31d70d1252ab69ad3d7be0b683a046827b5ce Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 9 May 2023 09:25:41 +0200 Subject: [PATCH] gen4 planner: allow last_insert_id with arguments (#13026) (#13036) Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_rewriting.go | 11 ++++-- .../planbuilder/testdata/dml_cases.json | 19 ++++++++++ .../planbuilder/testdata/select_cases.json | 35 ++++++++++++++++++- 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index 9253a451567..a7705d14049 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -493,13 +493,20 @@ var funcRewrites = map[string]string{ } func (er *astRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { - bindVar, found := funcRewrites[node.Name.Lowered()] + lowered := node.Name.Lowered() + if lowered == "last_insert_id" && len(node.Exprs) > 0 { + // if we are dealing with is LAST_INSERT_ID() with an argument, we don't need to rewrite it. + // with an argument, this is an identity function that will update the session state and + // sets the correct fields in the OK TCP packet that we send back + return + } + bindVar, found := funcRewrites[lowered] if found { if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc { return } if len(node.Exprs) > 0 { - er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) + er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", lowered) return } cursor.Replace(bindVarExpression(bindVar)) diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 469a75f203e..c6c0787a6b3 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -4605,5 +4605,24 @@ "Table": "user" } } + }, + { + "comment": "update using last_insert_id with an argument", + "query": "update main.m1 set foo = last_insert_id(foo+1) where id = 12345", + "plan": { + "QueryType": "UPDATE", + "Original": "update main.m1 set foo = last_insert_id(foo+1) where id = 12345", + "Instructions": { + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "MultiShardAutocommit": false, + "Query": "update m1 set foo = last_insert_id(foo + 1) where id = 12345" + } + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 46c961a54d3..c37edeef996 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -4779,7 +4779,6 @@ ] } } - }, { "comment": "Unmergeable subquery with multiple levels of derived statements, using a multi value `IN` predicate", @@ -5708,5 +5707,39 @@ "Vindex": "user_index" } } + }, + { + "comment": "allow last_insert_id with argument", + "query": "select last_insert_id(id) from user", + "v3-plan": { + "QueryType": "SELECT", + "Original": "select last_insert_id(id) from user", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select last_insert_id(id) from `user` where 1 != 1", + "Query": "select last_insert_id(id) from `user`", + "Table": "`user`" + } + }, + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select last_insert_id(id) from user", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select last_insert_id(id) from `user` where 1 != 1", + "Query": "select last_insert_id(id) from `user`", + "Table": "`user`" + } + } } ]