From 4bb0f041809f0a66df6b9086fa4441ef0647ee49 Mon Sep 17 00:00:00 2001 From: dev-jonghoonpark Date: Sat, 1 Feb 2025 00:42:53 +0900 Subject: [PATCH] impl: Add @Sql annotation support for R2DBC in Spring tests Signed-off-by: dev-jonghoonpark --- spring-test/spring-test.gradle | 2 + .../context/jdbc/R2dbcPopulatorUtils.java | 62 ++++++++++ .../jdbc/SqlScriptsTestExecutionListener.java | 11 +- .../TestContextReactiveTransactionUtils.java | 109 ++++++++++++++++++ .../test/r2dbc/R2dbcTestUtils.java | 86 ++++++++++++++ .../test/r2dbc/package-info.java | 7 ++ .../R2dbcSqlScriptsSpringJupiterTests.java | 56 +++++++++ .../reactive/EmptyReactiveDatabaseConfig.java | 46 ++++++++ ...R2dbcSqlScriptsSpringJupiterTests.test.sql | 1 + .../test/context/r2dbc/schema.sql | 4 + 10 files changed, 382 insertions(+), 2 deletions(-) create mode 100644 spring-test/src/main/java/org/springframework/test/context/jdbc/R2dbcPopulatorUtils.java create mode 100644 spring-test/src/main/java/org/springframework/test/context/transaction/reactive/TestContextReactiveTransactionUtils.java create mode 100644 spring-test/src/main/java/org/springframework/test/r2dbc/R2dbcTestUtils.java create mode 100644 spring-test/src/main/java/org/springframework/test/r2dbc/package-info.java create mode 100644 spring-test/src/test/java/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.java create mode 100644 spring-test/src/test/java/org/springframework/test/context/reactive/EmptyReactiveDatabaseConfig.java create mode 100644 spring-test/src/test/resources/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.test.sql create mode 100644 spring-test/src/test/resources/org/springframework/test/context/r2dbc/schema.sql diff --git a/spring-test/spring-test.gradle b/spring-test/spring-test.gradle index f703232a89df..7f6401aa4168 100644 --- a/spring-test/spring-test.gradle +++ b/spring-test/spring-test.gradle @@ -8,6 +8,7 @@ dependencies { optional(project(":spring-beans")) optional(project(":spring-context")) optional(project(":spring-jdbc")) + optional(project(":spring-r2dbc")) optional(project(":spring-orm")) optional(project(":spring-tx")) optional(project(":spring-web")) @@ -80,6 +81,7 @@ dependencies { testImplementation("org.hibernate.orm:hibernate-core") testImplementation("org.hibernate.validator:hibernate-validator") testImplementation("org.hsqldb:hsqldb") + testImplementation("io.r2dbc:r2dbc-h2") testImplementation("org.junit.platform:junit-platform-testkit") testRuntimeOnly("com.sun.xml.bind:jaxb-core") testRuntimeOnly("com.sun.xml.bind:jaxb-impl") diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/R2dbcPopulatorUtils.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/R2dbcPopulatorUtils.java new file mode 100644 index 000000000000..6e5c8f01856b --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/R2dbcPopulatorUtils.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.jdbc; + +import java.util.List; + +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.r2dbc.connection.init.ResourceDatabasePopulator; + +/** + * R2dbcPopulatorUtils is a separate class to avoid name conflicts with existing + * jdbc-related classes. + * + *

NOTE: In the current architecture, MergedSqlConfig is implemented + * as a package-private method, so it has been placed in + * org.springframework.test.context.jdbc. + * + * @author jonghoon park + * @since 7.0 + * @see SqlScriptsTestExecutionListener + * @see MergedSqlConfig + */ +public class R2dbcPopulatorUtils { + + static void execute(MergedSqlConfig mergedSqlConfig, ConnectionFactory connectionFactory, List scriptResources) { + ResourceDatabasePopulator populator = createResourceDatabasePopulator(mergedSqlConfig); + populator.setScripts(scriptResources.toArray(new Resource[0])); + + Mono.from(connectionFactory.create()) + .flatMap(populator::populate) + .block(); + } + + private static ResourceDatabasePopulator createResourceDatabasePopulator(MergedSqlConfig mergedSqlConfig) { + ResourceDatabasePopulator populator = new ResourceDatabasePopulator(); + populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding()); + populator.setSeparator(mergedSqlConfig.getSeparator()); + populator.setCommentPrefixes(mergedSqlConfig.getCommentPrefixes()); + populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter()); + populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter()); + populator.setContinueOnError(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.CONTINUE_ON_ERROR); + populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.IGNORE_FAILED_DROPS); + return populator; + } +} diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java index 5e741cec4a9f..86f360d257d8 100644 --- a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java @@ -23,6 +23,7 @@ import java.util.stream.Stream; import javax.sql.DataSource; +import io.r2dbc.spi.ConnectionFactory; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -45,6 +46,7 @@ import org.springframework.test.context.jdbc.SqlMergeMode.MergeMode; import org.springframework.test.context.support.AbstractTestExecutionListener; import org.springframework.test.context.transaction.TestContextTransactionUtils; +import org.springframework.test.context.transaction.reactive.TestContextReactiveTransactionUtils; import org.springframework.test.context.util.TestContextResourceUtils; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.TransactionDefinition; @@ -332,8 +334,13 @@ else if (logger.isDebugEnabled()) { Assert.state(!newTxRequired, () -> String.format("Failed to execute SQL scripts for test context %s: " + "cannot execute SQL scripts using Transaction Mode " + "[%s] without a PlatformTransactionManager.", testContext, TransactionMode.ISOLATED)); - Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for test context %s: " + - "supply at least a DataSource or PlatformTransactionManager.", testContext)); + if (dataSource == null) { + ConnectionFactory connectionFactory = TestContextReactiveTransactionUtils.retrieveConnectionFactory(testContext); + Assert.state(connectionFactory != null, () -> String.format("Failed to execute SQL scripts for test context %s: " + + "supply at least a DataSource or PlatformTransactionManager or ConnectionFactory.", testContext)); + R2dbcPopulatorUtils.execute(mergedSqlConfig, connectionFactory, scriptResources); + return; + } // Execute scripts directly against the DataSource populator.execute(dataSource); } diff --git a/spring-test/src/main/java/org/springframework/test/context/transaction/reactive/TestContextReactiveTransactionUtils.java b/spring-test/src/main/java/org/springframework/test/context/transaction/reactive/TestContextReactiveTransactionUtils.java new file mode 100644 index 000000000000..a000507f77dd --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/context/transaction/reactive/TestContextReactiveTransactionUtils.java @@ -0,0 +1,109 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.reactive; + +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.jspecify.annotations.Nullable; +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.test.context.TestContext; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.util.Assert; + +/** + * Utility methods for working with transactions and data access related beans + * within the Spring TestContext Framework. + * + *

Mainly for internal use within the framework. + * + * @author jonghoon park + * @since 7.0 + */ +public class TestContextReactiveTransactionUtils { + + /** + * Default bean name for a {@link ConnectionFactory}: + * {@code "connectionFactory"}. + */ + public static final String DEFAULT_CONNECTION_FACTORY_NAME = "connectionFactory"; + + + private static final Log logger = LogFactory.getLog(TestContextReactiveTransactionUtils.class); + + /** + * Retrieve the {@link ConnectionFactory} to use for the supplied {@linkplain TestContext + * test context}. + *

The following algorithm is used to retrieve the {@code ConnectionFactory} from + * the {@link org.springframework.context.ApplicationContext ApplicationContext} + * of the supplied test context: + *

    + *
  1. Attempt to look up the single {@code ConnectionFactory} by type. + *
  2. Attempt to look up the primary {@code ConnectionFactory} by type. + *
  3. Attempt to look up the {@code ConnectionFactory} by type and the + * {@linkplain #DEFAULT_CONNECTION_FACTORY_NAME default data source name}. + *
+ * @param testContext the test context for which the {@code ConnectionFactory} + * should be retrieved; never {@code null} + * @return the {@code DataSource} to use, or {@code null} if not found + */ + @Nullable + public static ConnectionFactory retrieveConnectionFactory(TestContext testContext) { + Assert.notNull(testContext, "TestContext must not be null"); + BeanFactory bf = testContext.getApplicationContext().getAutowireCapableBeanFactory(); + + try { + if (bf instanceof ListableBeanFactory lbf) { + // Look up single bean by type + Map ConnectionFactories = + BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, ConnectionFactory.class); + if (ConnectionFactories.size() == 1) { + return ConnectionFactories.values().iterator().next(); + } + + try { + // look up single bean by type, with support for 'primary' beans + return bf.getBean(ConnectionFactory.class); + } + catch (BeansException ex) { + logBeansException(testContext, ex, PlatformTransactionManager.class); + } + } + + // look up by type and default name + return bf.getBean(DEFAULT_CONNECTION_FACTORY_NAME, ConnectionFactory.class); + } + catch (BeansException ex) { + logBeansException(testContext, ex, Connection.class); + return null; + } + } + + private static void logBeansException(TestContext testContext, BeansException ex, Class beanType) { + if (logger.isTraceEnabled()) { + logger.trace("Caught exception while retrieving %s for test context %s" + .formatted(beanType.getSimpleName(), testContext), ex); + } + } +} diff --git a/spring-test/src/main/java/org/springframework/test/r2dbc/R2dbcTestUtils.java b/spring-test/src/main/java/org/springframework/test/r2dbc/R2dbcTestUtils.java new file mode 100644 index 000000000000..1169064658eb --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/r2dbc/R2dbcTestUtils.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.r2dbc; + +import java.util.Objects; + +import org.jspecify.annotations.Nullable; +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Mono; + +import org.springframework.r2dbc.core.DatabaseClient; +import org.springframework.util.StringUtils; + +/** + * {@code R2dbcTestUtils} is a collection of R2DBC related utility functions + * intended to simplify standard database testing scenarios. + * + * @author jonghoon park + * @since 7.0 + * @see org.springframework.r2dbc.core.DatabaseClient + */ +public abstract class R2dbcTestUtils { + + /** + * Count the rows in the given table. + * @param connectionFactory the {@link ConnectionFactory} with which to perform R2DBC + * operations + * @param tableName name of the table to count rows in + * @return the number of rows in the table + */ + public static Mono countRowsInTable(ConnectionFactory connectionFactory, String tableName) { + return countRowsInTable(DatabaseClient.create(connectionFactory), tableName); + } + + /** + * Count the rows in the given table. + * @param databaseClient the {@link DatabaseClient} with which to perform R2DBC + * operations + * @param tableName name of the table to count rows in + * @return the number of rows in the table + */ + public static Mono countRowsInTable(DatabaseClient databaseClient, String tableName) { + return countRowsInTableWhere(databaseClient, tableName, null); + } + + /** + * Count the rows in the given table, using the provided {@code WHERE} clause. + *

If the provided {@code WHERE} clause contains text, it will be prefixed + * with {@code " WHERE "} and then appended to the generated {@code SELECT} + * statement. For example, if the provided table name is {@code "person"} and + * the provided where clause is {@code "name = 'Bob' and age > 25"}, the + * resulting SQL statement to execute will be + * {@code "SELECT COUNT(0) FROM person WHERE name = 'Bob' and age > 25"}. + * @param databaseClient the {@link DatabaseClient} with which to perform JDBC + * operations + * @param tableName the name of the table to count rows in + * @param whereClause the {@code WHERE} clause to append to the query + * @return the number of rows in the table that match the provided + * {@code WHERE} clause + */ + public static Mono countRowsInTableWhere( + DatabaseClient databaseClient, String tableName, @Nullable String whereClause) { + + String sql = "SELECT COUNT(0) FROM " + tableName; + if (StringUtils.hasText(whereClause)) { + sql += " WHERE " + whereClause; + } + return databaseClient.sql(sql) + .map(row -> Objects.requireNonNull(row.get(0, Long.class)).intValue()) + .one(); + } +} diff --git a/spring-test/src/main/java/org/springframework/test/r2dbc/package-info.java b/spring-test/src/main/java/org/springframework/test/r2dbc/package-info.java new file mode 100644 index 000000000000..825098dcf7fb --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/r2dbc/package-info.java @@ -0,0 +1,7 @@ +/** + * Support classes for tests based on R2DBC. + */ +@NullMarked +package org.springframework.test.r2dbc; + +import org.jspecify.annotations.NullMarked; diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.java new file mode 100644 index 000000000000..1417c36cd3ea --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.aot.samples.r2dbc; + +import io.r2dbc.spi.ConnectionFactory; +import reactor.test.StepVerifier; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.TestPropertySource; +import org.springframework.test.context.reactive.EmptyReactiveDatabaseConfig; +import org.springframework.test.context.jdbc.Sql; +import org.springframework.test.context.jdbc.SqlMergeMode; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.transaction.annotation.Transactional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.MERGE; +import static org.springframework.test.r2dbc.R2dbcTestUtils.countRowsInTable; + +/** + * @author jonghoon park + * @since 7.0 + */ +@SpringJUnitConfig(EmptyReactiveDatabaseConfig.class) +@Transactional +@SqlMergeMode(MERGE) +@Sql("/org/springframework/test/context/r2dbc/schema.sql") +@DirtiesContext +@TestPropertySource(properties = "test.engine = jupiter") +public class R2dbcSqlScriptsSpringJupiterTests { + + @Test + @Sql // default script --> org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.test.sql + void test(@Autowired ConnectionFactory connectionFactory) { + StepVerifier.create(countRowsInTable(connectionFactory, "users")) + .assertNext(count -> assertThat(count).isEqualTo(1)) + .verifyComplete(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/reactive/EmptyReactiveDatabaseConfig.java b/spring-test/src/test/java/org/springframework/test/context/reactive/EmptyReactiveDatabaseConfig.java new file mode 100644 index 000000000000..1e8f762ba6f5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/reactive/EmptyReactiveDatabaseConfig.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.reactive; + +import io.r2dbc.spi.ConnectionFactory; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.r2dbc.connection.R2dbcTransactionManager; +import org.springframework.r2dbc.connection.SingleConnectionFactory; +import org.springframework.transaction.ReactiveTransactionManager; + +/** + * Empty reactive database configuration class for SQL script integration tests. + * + * @author jonghoon park + * @since 7.0 + */ +@Configuration +public class EmptyReactiveDatabaseConfig { + + @Bean + ConnectionFactory connectionFactory() { + SingleConnectionFactory factory = new SingleConnectionFactory("r2dbc:h2:mem:///testdb?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE;DATABASE_TO_LOWER=TRUE", false); + return factory.unwrap(); + } + + @Bean + ReactiveTransactionManager transactionManager(ConnectionFactory connectionFactory) { + return new R2dbcTransactionManager(connectionFactory); + } +} diff --git a/spring-test/src/test/resources/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.test.sql b/spring-test/src/test/resources/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.test.sql new file mode 100644 index 000000000000..8c5bb0587d7a --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.test.sql @@ -0,0 +1 @@ +INSERT INTO users VALUES('Daisy'); diff --git a/spring-test/src/test/resources/org/springframework/test/context/r2dbc/schema.sql b/spring-test/src/test/resources/org/springframework/test/context/r2dbc/schema.sql new file mode 100644 index 000000000000..2068a019d601 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/r2dbc/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE users ( + name VARCHAR(20) NOT NULL, + PRIMARY KEY(name) +);