Skip to content

Commit

Permalink
DatabaseClient uses SQL Supplier more lazily
Browse files Browse the repository at this point in the history
This commit modifies the `DefaultDatabaseClient` implementation in order
to ensure lazier usage of the `Supplier<String>` passed to the sql
method (`DatabaseClient#sql(Supplier)`).

Since technically `DatabaseClient` is an interface that could have 3rd
party implementations, the lazyness expectation is only hinted at in the
`DatabaseClient#sql` javadoc.

Possible caveat: some log statements attempt to reflect the now lazily
resolved SQL string. Similarly, some exceptions can capture the SQL that
caused the issue if known. We expect that these always occur after the
execution of the statement has been attempted (see `ResultFunction`).
At this point the SQL string will be accessible and logs and exceptions
should reflect it as before. Keep an eye out for such strings turning
into `null` after this change, which would indicate the opposite.

Backport of d72df5a
See gh-29367
Closes gh-29887
  • Loading branch information
simonbasle committed Jan 26, 2023
1 parent e4e90bb commit de53d77
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,34 +20,18 @@

import io.r2dbc.spi.Connection;


/**
* Union type combining {@link Function} and {@link SqlProvider} to expose the SQL that is
* related to the underlying action.
* related to the underlying action. The SqlProvider can support lazy / generate once semantics,
* in which case {@link #getSql()} can be {@code null} until the {@code #apply(Connection)}
* method is invoked.
*
* @author Mark Paluch
* @author Simon Baslé
* @since 5.3
* @param <R> the type of the result of the function.
*/
class ConnectionFunction<R> implements Function<Connection, R>, SqlProvider {

private final String sql;

private final Function<Connection, R> function;


ConnectionFunction(String sql, Function<Connection, R> function) {
this.sql = sql;
this.function = function;
}


@Override
public R apply(Connection t) {
return this.function.apply(t);
}

@Override
public String getSql() {
return this.sql;
}
interface ConnectionFunction<R> extends Function<Connection, R>, SqlProvider {
}

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -79,7 +79,10 @@ public interface DatabaseClient extends ConnectionAccessor {
* the execution. The SQL string can contain either native parameter
* bind markers or named parameters (e.g. {@literal :foo, :bar}) when
* {@link NamedParameterExpander} is enabled.
* <p>Accepts {@link PreparedOperation} as SQL and binding {@link Supplier}
* <p>Accepts {@link PreparedOperation} as SQL and binding {@link Supplier}.
* <p>{code DatabaseClient} implementations should defer the resolution of
* the SQL string as much as possible, ideally up to the point where a
* {@code Subscription} happens. This is the case for the default implementation.
* @param sqlSupplier a supplier for the SQL statement
* @return a new {@link GenericExecuteSpec}
* @see NamedParameterExpander
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -60,6 +60,7 @@
* @author Mark Paluch
* @author Mingyuan Wu
* @author Bogdan Ilchyshyn
* @author Simon Baslé
* @since 5.3
*/
class DefaultDatabaseClient implements DatabaseClient {
Expand Down Expand Up @@ -322,9 +323,8 @@ public Mono<Void> then() {
return fetch().rowsUpdated().then();
}

private <T> FetchSpec<T> execute(Supplier<String> sqlSupplier, BiFunction<Row, RowMetadata, T> mappingFunction) {
String sql = getRequiredSql(sqlSupplier);
Function<Connection, Statement> statementFunction = connection -> {
private ResultFunction getResultFunction(Supplier<String> sqlSupplier) {
BiFunction<Connection, String, Statement> statementFunction = (connection, sql) -> {
if (logger.isDebugEnabled()) {
logger.debug("Executing SQL statement [" + sql + "]");
}
Expand Down Expand Up @@ -370,16 +370,16 @@ private <T> FetchSpec<T> execute(Supplier<String> sqlSupplier, BiFunction<Row, R
return statement;
};

Function<Connection, Flux<Result>> resultFunction = connection -> {
Statement statement = statementFunction.apply(connection);
return Flux.from(this.filterFunction.filter(statement, DefaultDatabaseClient.this.executeFunction))
.cast(Result.class).checkpoint("SQL \"" + sql + "\" [DatabaseClient]");
};
return new ResultFunction(sqlSupplier, statementFunction, this.filterFunction, DefaultDatabaseClient.this.executeFunction);
}

private <T> FetchSpec<T> execute(Supplier<String> sqlSupplier, BiFunction<Row, RowMetadata, T> mappingFunction) {
ResultFunction resultHandler = getResultFunction(sqlSupplier);

return new DefaultFetchSpec<>(
DefaultDatabaseClient.this, sql,
new ConnectionFunction<>(sql, resultFunction),
new ConnectionFunction<>(sql, connection -> sumRowsUpdated(resultFunction, connection)),
DefaultDatabaseClient.this,
resultHandler,
connection -> sumRowsUpdated(resultHandler, connection),
mappingFunction);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,6 @@
import java.util.function.Function;

import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import reactor.core.publisher.Flux;
Expand All @@ -32,31 +31,29 @@
* Default {@link FetchSpec} implementation.
*
* @author Mark Paluch
* @author Simon Baslé
* @since 5.3
* @param <T> the row result type
*/
class DefaultFetchSpec<T> implements FetchSpec<T> {

private final ConnectionAccessor connectionAccessor;

private final String sql;

private final Function<Connection, Flux<Result>> resultFunction;
private final ResultFunction resultFunction;

private final Function<Connection, Mono<Integer>> updatedRowsFunction;

private final BiFunction<Row, RowMetadata, T> mappingFunction;


DefaultFetchSpec(ConnectionAccessor connectionAccessor, String sql,
Function<Connection, Flux<Result>> resultFunction,
DefaultFetchSpec(ConnectionAccessor connectionAccessor,
ResultFunction resultFunction,
Function<Connection, Mono<Integer>> updatedRowsFunction,
BiFunction<Row, RowMetadata, T> mappingFunction) {

this.sql = sql;
this.connectionAccessor = connectionAccessor;
this.resultFunction = resultFunction;
this.updatedRowsFunction = updatedRowsFunction;
this.updatedRowsFunction = new DelegateConnectionFunction<>(resultFunction, updatedRowsFunction);
this.mappingFunction = mappingFunction;
}

Expand All @@ -70,7 +67,7 @@ public Mono<T> one() {
}
if (list.size() > 1) {
return Mono.error(new IncorrectResultSizeDataAccessException(
String.format("Query [%s] returned non unique result.", this.sql),
String.format("Query [%s] returned non unique result.", this.resultFunction.getSql()),
1));
}
return Mono.just(list.get(0));
Expand All @@ -84,7 +81,7 @@ public Mono<T> first() {

@Override
public Flux<T> all() {
return this.connectionAccessor.inConnectionMany(new ConnectionFunction<>(this.sql,
return this.connectionAccessor.inConnectionMany(new DelegateConnectionFunction<>(this.resultFunction,
connection -> this.resultFunction.apply(connection)
.flatMap(result -> result.map(this.mappingFunction))));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.r2dbc.core;

import java.util.function.Function;

import io.r2dbc.spi.Connection;

import org.springframework.lang.Nullable;

/**
* A {@link ConnectionFunction} that delegates to a {@code SqlProvider} and a plain
* {@code Function}.
*
* @author Simon Baslé
* @since 5.3.26
* @param <R> the type of the result of the function.
*/
final class DelegateConnectionFunction<R> implements ConnectionFunction<R> {

private final SqlProvider sql;

private final Function<Connection, R> function;


DelegateConnectionFunction(SqlProvider sql, Function<Connection, R> function) {
this.sql = sql;
this.function = function;
}


@Override
public R apply(Connection t) {
return this.function.apply(t);
}

@Nullable
@Override
public String getSql() {
return this.sql.getSql();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.r2dbc.core;

import java.util.function.BiFunction;
import java.util.function.Supplier;

import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Statement;
import reactor.core.publisher.Flux;

import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* A {@link ConnectionFunction} that produces a {@code Flux} of {@link Result} and that
* defers generation of the SQL until the function has been applied.
* Beforehand, the {@code getSql()} method simply returns {@code null}. The sql String is
* also memoized during application, so that subsequent calls to {@link #getSql()} return
* the same {@code String} without further calls to the {@code Supplier}.
*
* @author Mark Paluch
* @author Simon Baslé
* @since 5.3.26
*/
final class ResultFunction implements ConnectionFunction<Flux<Result>> {

final Supplier<String> sqlSupplier;
final BiFunction<Connection, String, Statement> statementFunction;
final StatementFilterFunction filterFunction;
final ExecuteFunction executeFunction;

@Nullable
String resolvedSql = null;

ResultFunction(Supplier<String> sqlSupplier, BiFunction<Connection, String, Statement> statementFunction, StatementFilterFunction filterFunction, ExecuteFunction executeFunction) {
this.sqlSupplier = sqlSupplier;
this.statementFunction = statementFunction;
this.filterFunction = filterFunction;
this.executeFunction = executeFunction;
}

@Override
public Flux<Result> apply(Connection connection) {
String sql = this.sqlSupplier.get();
Assert.state(StringUtils.hasText(sql), "SQL returned by supplier must not be empty");
this.resolvedSql = sql;
Statement statement = this.statementFunction.apply(connection, sql);
return Flux.from(this.filterFunction.filter(statement, this.executeFunction))
.cast(Result.class).checkpoint("SQL \"" + sql + "\" [DatabaseClient]");
}

@Nullable
@Override
public String getSql() {
return this.resolvedSql;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.r2dbc.core;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
Expand Down Expand Up @@ -64,6 +66,7 @@
* @author Mark Paluch
* @author Ferdinand Jacobs
* @author Jens Schauder
* @author Simon Baslé
*/
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
Expand Down Expand Up @@ -397,6 +400,47 @@ void shouldApplySimpleStatementFilterFunctions() {
inOrder.verifyNoMoreInteractions();
}

@Test
void sqlSupplierInvocationIsDeferredUntilSubscription() {
// We'll have either 2 or 3 rows, depending on the subscription and the generated SQL
MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(
MockColumnMetadata.builder().name("id").javaType(Integer.class).build()).build();
final MockRow row1 = MockRow.builder().identified("id", Integer.class, 1).build();
final MockRow row2 = MockRow.builder().identified("id", Integer.class, 2).build();
final MockRow row3 = MockRow.builder().identified("id", Integer.class, 3).build();
// Set up 2 mock statements
mockStatementFor("SELECT id FROM test WHERE id < '3'", MockResult.builder()
.rowMetadata(metadata)
.row(row1, row2).build());
mockStatementFor("SELECT id FROM test WHERE id < '4'", MockResult.builder()
.rowMetadata(metadata)
.row(row1, row2, row3).build());
// Create the client
DatabaseClient databaseClient = this.databaseClientBuilder.build();

AtomicInteger invoked = new AtomicInteger();
// Assemble a publisher, but don't subscribe yet
Mono<List<Integer>> operation = databaseClient
.sql(() -> {
int idMax = 2 + invoked.incrementAndGet();
return String.format("SELECT id FROM test WHERE id < '%s'", idMax);
})
.map(r -> r.get("id", Integer.class))
.all()
.collectList();

assertThat(invoked).as("invoked (before subscription)").hasValue(0);

List<Integer> rows = operation.block();
assertThat(invoked).as("invoked (after 1st subscription)").hasValue(1);
assertThat(rows).containsExactly(1, 2);

rows = operation.block();
assertThat(invoked).as("invoked (after 2nd subscription)").hasValue(2);
assertThat(rows).containsExactly(1, 2, 3);
}


private Statement mockStatement() {
return mockStatementFor(null, null);
}
Expand Down

0 comments on commit de53d77

Please sign in to comment.