diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java b/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java index 02f0179845..8cc27aa46f 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java @@ -16,6 +16,8 @@ package org.springframework.batch.core; +import java.util.Set; + /** * Enumeration representing the status of an execution. * @@ -71,6 +73,8 @@ public enum BatchStatus { */ UNKNOWN; + public static final Set RUNNING_STATUSES = Set.of(STARTING, STARTED, STOPPING); + /** * Convenience method to return the higher value status of the statuses passed to the * method. @@ -87,7 +91,7 @@ public static BatchStatus max(BatchStatus status1, BatchStatus status2) { * @return true if the status is STARTING, STARTED, STOPPING */ public boolean isRunning() { - return this == STARTING || this == STARTED || this == STOPPING; + return RUNNING_STATUSES.contains(this); } /** diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java index 85c69655f9..17ad7f521a 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Set; +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.JobParameters; @@ -87,6 +88,14 @@ default JobInstance getLastJobInstance(String jobName) { @Nullable StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable Long stepExecutionId); + /** + * Find {@link StepExecution}s by IDs and parent {@link JobExecution} ID + * @param jobExecutionId given job execution id + * @param stepExecutionIds given step execution ids + * @return collection of {@link StepExecution} + */ + Set getStepExecutions(Long jobExecutionId, Set stepExecutionIds); + /** * @param instanceId {@link Long} The ID for the {@link JobInstance} to obtain. * @return the {@code JobInstance} that has this ID, or {@code null} if not found. @@ -170,4 +179,13 @@ default JobExecution getLastJobExecution(JobInstance jobInstance) { */ long getJobInstanceCount(@Nullable String jobName) throws NoSuchJobException; + /** + * Retrieve number of step executions that match the step execution ids and the batch + * statuses + * @param stepExecutionIds given step execution ids + * @param matchingBatchStatuses given batch statuses to match against + * @return number of {@link StepExecution} matching the criteria + */ + long getStepExecutionCount(Set stepExecutionIds, Set matchingBatchStatuses); + } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java index 236be9902d..89fcd3368e 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java @@ -16,6 +16,7 @@ package org.springframework.batch.core.explore.support; +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.JobParameters; @@ -147,6 +148,19 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L return stepExecution; } + @Nullable + @Override + public Set getStepExecutions(Long jobExecutionId, Set stepExecutionIds) { + JobExecution jobExecution = jobExecutionDao.getJobExecution(jobExecutionId); + if (jobExecution == null) { + return null; + } + getJobExecutionDependencies(jobExecution); + Set stepExecutions = stepExecutionDao.getStepExecutions(jobExecution, stepExecutionIds); + stepExecutions.forEach(this::getStepExecutionDependencies); + return stepExecutions; + } + @Nullable @Override public JobInstance getJobInstance(@Nullable Long instanceId) { @@ -180,6 +194,14 @@ public long getJobInstanceCount(@Nullable String jobName) throws NoSuchJobExcept return jobInstanceDao.getJobInstanceCount(jobName); } + @Override + public long getStepExecutionCount(Set stepExecutionIds, Set matchingBatchStatuses) { + if (stepExecutionIds.isEmpty() || matchingBatchStatuses.isEmpty()) { + return 0; + } + return stepExecutionDao.countStepExecutions(stepExecutionIds, matchingBatchStatuses); + } + /** * @return instance of {@link JobInstanceDao}. * @since 5.1 diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/AbstractJdbcBatchMetadataDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/AbstractJdbcBatchMetadataDao.java index b755651fb5..3129552282 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/AbstractJdbcBatchMetadataDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/AbstractJdbcBatchMetadataDao.java @@ -17,6 +17,9 @@ package org.springframework.batch.core.repository.dao; import java.sql.Types; +import java.util.Collection; +import java.util.Map; +import java.util.stream.Collectors; import org.springframework.beans.factory.InitializingBean; import org.springframework.jdbc.core.JdbcOperations; @@ -51,6 +54,14 @@ protected String getQuery(String base) { return StringUtils.replace(base, "%PREFIX%", tablePrefix); } + protected String getQuery(String base, Map> collectionParams) { + String query = getQuery(base); + for (Map.Entry> collectionParam : collectionParams.entrySet()) { + query = createParameterizedQuery(query, collectionParam.getKey(), collectionParam.getValue()); + } + return query; + } + protected String getTablePrefix() { return tablePrefix; } @@ -80,6 +91,18 @@ public void setClobTypeToUse(int clobTypeToUse) { this.clobTypeToUse = clobTypeToUse; } + /** + * Replaces a given placeholder with a number of parameters (i.e. "?"). + * @param sqlTemplate given sql template + * @param placeholder placeholder that is being used for parameters + * @param parameters collection of parameters with variable size + * @return sql query replaced with a number of parameters + */ + private static String createParameterizedQuery(String sqlTemplate, String placeholder, Collection parameters) { + String params = parameters.stream().map(p -> "?").collect(Collectors.joining(", ")); + return sqlTemplate.replace(placeholder, params); + } + @Override public void afterPropertiesSet() throws Exception { Assert.state(jdbcTemplate != null, "JdbcOperations is required"); diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java index b1e46e0c23..90c03e70e3 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java @@ -27,8 +27,11 @@ import java.util.Comparator; import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Stream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -93,6 +96,16 @@ public class JdbcStepExecutionDao extends AbstractJdbcBatchMetadataDao implement private static final String GET_STEP_EXECUTION = GET_RAW_STEP_EXECUTIONS + " AND STEP_EXECUTION_ID = ?"; + private static final String GET_STEP_EXECUTIONS_BY_IDS = GET_RAW_STEP_EXECUTIONS + + " and STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%)"; + + private static final String COUNT_STEP_EXECUTIONS_BY_IDS_AND_STATUSES = """ + SELECT COUNT(*) + FROM %PREFIX%STEP_EXECUTION SE + WHERE SE.STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%) + AND SE.STATUS IN (%STEP_STATUSES%) + """; + private static final String GET_LAST_STEP_EXECUTION = """ SELECT SE.STEP_EXECUTION_ID, SE.STEP_NAME, SE.START_TIME, SE.END_TIME, SE.STATUS, SE.COMMIT_COUNT, SE.READ_COUNT, SE.FILTER_COUNT, SE.WRITE_COUNT, SE.EXIT_CODE, SE.EXIT_MESSAGE, SE.READ_SKIP_COUNT, SE.WRITE_SKIP_COUNT, SE.PROCESS_SKIP_COUNT, SE.ROLLBACK_COUNT, SE.LAST_UPDATED, SE.VERSION, SE.CREATE_TIME, JE.JOB_EXECUTION_ID, JE.START_TIME, JE.END_TIME, JE.STATUS, JE.EXIT_CODE, JE.EXIT_MESSAGE, JE.CREATE_TIME, JE.LAST_UPDATED, JE.VERSION FROM %PREFIX%JOB_EXECUTION JE @@ -337,6 +350,16 @@ public StepExecution getStepExecution(JobExecution jobExecution, Long stepExecut } } + @Override + @Nullable + public Set getStepExecutions(JobExecution jobExecution, Set stepExecutionIds) { + List executions = getJdbcTemplate().query( + getQuery(GET_STEP_EXECUTIONS_BY_IDS, Map.of("%STEP_EXECUTION_IDS%", stepExecutionIds)), + new StepExecutionRowMapper(jobExecution), + Stream.concat(Stream.of(jobExecution.getId()), stepExecutionIds.stream()).toArray(Object[]::new)); + return Set.copyOf(executions); + } + @Override public StepExecution getLastStepExecution(JobInstance jobInstance, String stepName) { List executions = getJdbcTemplate().query(getQuery(GET_LAST_STEP_EXECUTION), (rs, rowNum) -> { @@ -360,6 +383,16 @@ public StepExecution getLastStepExecution(JobInstance jobInstance, String stepNa } } + @Override + public long countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses) { + return getJdbcTemplate().queryForObject( + getQuery(COUNT_STEP_EXECUTIONS_BY_IDS_AND_STATUSES, + Map.of("%STEP_EXECUTION_IDS%", stepExecutionIds, "%STEP_STATUSES%", matchingBatchStatuses)), + Long.class, + Stream.concat(stepExecutionIds.stream(), matchingBatchStatuses.stream().map(BatchStatus::name)) + .toArray(Object[]::new)); + } + @Override public void addStepExecutions(JobExecution jobExecution) { getJdbcTemplate().query(getQuery(GET_STEP_EXECUTIONS), new StepExecutionRowMapper(jobExecution), diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MongoStepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MongoStepExecutionDao.java index 9b889c1d81..b1a7449242 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MongoStepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MongoStepExecutionDao.java @@ -20,7 +20,10 @@ import java.util.Comparator; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; @@ -95,6 +98,17 @@ public StepExecution getStepExecution(JobExecution jobExecution, Long stepExecut return stepExecution != null ? this.stepExecutionConverter.toStepExecution(stepExecution, jobExecution) : null; } + @Override + public Set getStepExecutions(JobExecution jobExecution, Set stepExecutionIds) { + Query query = query(where("stepExecutionId").in(stepExecutionIds)); + List stepExecutions = this.mongoOperations + .find(query, org.springframework.batch.core.repository.persistence.StepExecution.class, + STEP_EXECUTIONS_COLLECTION_NAME); + return stepExecutions.stream() + .map(stepExecution -> this.stepExecutionConverter.toStepExecution(stepExecution, jobExecution)) + .collect(Collectors.toSet()); + } + @Override public StepExecution getLastStepExecution(JobInstance jobInstance, String stepName) { // TODO optimize the query @@ -160,4 +174,12 @@ public long countStepExecutions(JobInstance jobInstance, String stepName) { return count; } + @Override + public long countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses) { + Query query = query(where("jobExecutionId").is(stepExecutionIds).and("status").in(matchingBatchStatuses)); + return this.mongoOperations.count(query, + org.springframework.batch.core.repository.persistence.StepExecution.class, + STEP_EXECUTIONS_COLLECTION_NAME); + } + } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java index 58e43bd8ef..00af0487d0 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java @@ -17,7 +17,9 @@ package org.springframework.batch.core.repository.dao; import java.util.Collection; +import java.util.Set; +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; @@ -62,6 +64,15 @@ public interface StepExecutionDao { @Nullable StepExecution getStepExecution(JobExecution jobExecution, Long stepExecutionId); + /** + * Get a collection of {@link StepExecution} matching job execution and step execution + * ids. + * @param jobExecution the parent job execution + * @param stepExecutionIds the step execution ids + * @return collection of {@link StepExecution} + */ + Set getStepExecutions(JobExecution jobExecution, Set stepExecutionIds); + /** * Retrieve the last {@link StepExecution} for a given {@link JobInstance} ordered by * creation time and then id. @@ -91,6 +102,15 @@ default long countStepExecutions(JobInstance jobInstance, String stepName) { throw new UnsupportedOperationException(); } + /** + * Count {@link StepExecution} that match the ids and statuses of them - avoid loading + * them into memory + * @param stepExecutionIds given step execution ids + * @param matchingBatchStatuses + * @return the count of matching steps + */ + long countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses); + /** * Delete the given step execution. * @param stepExecution the step execution to delete diff --git a/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java b/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java index 625145f0e2..64687af28f 100644 --- a/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java +++ b/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java @@ -551,6 +551,11 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L throw new UnsupportedOperationException(); } + @Override + public Set getStepExecutions(Long jobExecutionId, Set stepExecutionIds) { + return Set.of(); + } + @Override public List getJobNames() { throw new UnsupportedOperationException(); @@ -579,6 +584,11 @@ public long getJobInstanceCount(@Nullable String jobName) throws NoSuchJobExcept } } + @Override + public long getStepExecutionCount(Set stepExecutionIds, Set matchingBatchStatuses) { + return 0; + } + } public static class StubJobParametersConverter implements JobParametersConverter { diff --git a/spring-batch-core/src/test/java/org/springframework/batch/core/repository/support/SimpleJobRepositoryTests.java b/spring-batch-core/src/test/java/org/springframework/batch/core/repository/support/SimpleJobRepositoryTests.java index 5b903d0f14..57e8fabab5 100644 --- a/spring-batch-core/src/test/java/org/springframework/batch/core/repository/support/SimpleJobRepositoryTests.java +++ b/spring-batch-core/src/test/java/org/springframework/batch/core/repository/support/SimpleJobRepositoryTests.java @@ -202,7 +202,7 @@ void testSaveStepExecutionSetsLastUpdated() { assertNotNull(stepExecution.getLastUpdated()); LocalDateTime lastUpdated = stepExecution.getLastUpdated(); - assertTrue(lastUpdated.isAfter(before)); + assertFalse(lastUpdated.isBefore(before)); } @Test @@ -236,7 +236,7 @@ void testUpdateStepExecutionSetsLastUpdated() { assertNotNull(stepExecution.getLastUpdated()); LocalDateTime lastUpdated = stepExecution.getLastUpdated(); - assertTrue(lastUpdated.isAfter(before)); + assertFalse(lastUpdated.isBefore(before)); } @Test diff --git a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java index f0c710c544..3e4cac2100 100644 --- a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java +++ b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java @@ -28,7 +28,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.batch.core.JobExecution; +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.Step; import org.springframework.batch.core.StepExecution; import org.springframework.batch.core.explore.JobExplorer; @@ -251,25 +251,20 @@ protected Set doHandle(StepExecution managerStepExecution, private Set pollReplies(final StepExecution managerStepExecution, final Set split) throws Exception { - Set partitionStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet()); Callable> callback = () -> { - JobExecution jobExecution = jobExplorer.getJobExecution(managerStepExecution.getJobExecutionId()); - Set finishedStepExecutions = jobExecution.getStepExecutions() - .stream() - .filter(stepExecution -> partitionStepExecutionIds.contains(stepExecution.getId())) - .filter(stepExecution -> !stepExecution.getStatus().isRunning()) - .collect(Collectors.toSet()); - - if (logger.isDebugEnabled()) { - logger.debug(String.format("Currently waiting on %s partitions to finish", split.size())); - } - - if (finishedStepExecutions.size() == split.size()) { - return finishedStepExecutions; + Set currentStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet()); + long runningStepExecutions = jobExplorer.getStepExecutionCount(currentStepExecutionIds, + BatchStatus.RUNNING_STATUSES); + if (runningStepExecutions > 0 && !split.isEmpty()) { + if (logger.isDebugEnabled()) { + logger.debug(String.format("Currently waiting on %s out of %s partitions to finish", + runningStepExecutions, split.size())); + } + return null; } else { - return null; + return jobExplorer.getStepExecutions(managerStepExecution.getJobExecutionId(), currentStepExecutionIds); } }; diff --git a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java index 4f7b677649..3a07c4de22 100644 --- a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java +++ b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java @@ -21,6 +21,8 @@ import java.util.Collections; import java.util.HashSet; import java.util.concurrent.TimeoutException; +import java.util.Set; +import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -181,6 +183,11 @@ void testHandleWithJobRepositoryPolling() throws Exception { completedJobExecution.addStepExecutions(Arrays.asList(partition2, partition1, partition4)); when(jobExplorer.getJobExecution(5L)).thenReturn(runningJobExecution, runningJobExecution, runningJobExecution, completedJobExecution); + Set stepExecutionIds = stepExecutions.stream().map(StepExecution::getId).collect(Collectors.toSet()); + when(jobExplorer.getStepExecutionCount(stepExecutionIds, BatchStatus.RUNNING_STATUSES)).thenReturn(3L, 2L, 1L, + 0L); + Set completedStepExecutions = Set.of(partition2, partition1, partition4); + when(jobExplorer.getStepExecutions(jobExecution.getId(), stepExecutionIds)).thenReturn(completedStepExecutions); // set messageChannelPartitionHandler.setMessagingOperations(operations); @@ -200,6 +207,8 @@ void testHandleWithJobRepositoryPolling() throws Exception { assertTrue(executions.contains(partition4)); // verify + verify(jobExplorer, times(4)).getStepExecutionCount(stepExecutionIds, BatchStatus.RUNNING_STATUSES); + verify(jobExplorer, times(1)).getStepExecutions(jobExecution.getId(), stepExecutionIds); verify(operations, times(3)).send(any(Message.class)); } @@ -228,6 +237,8 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception { JobExecution runningJobExecution = new JobExecution(5L, new JobParameters()); runningJobExecution.addStepExecutions(Arrays.asList(partition2, partition1, partition3)); when(jobExplorer.getJobExecution(5L)).thenReturn(runningJobExecution); + Set stepExecutionIds = stepExecutions.stream().map(StepExecution::getId).collect(Collectors.toSet()); + when(jobExplorer.getStepExecutionCount(stepExecutionIds, BatchStatus.RUNNING_STATUSES)).thenReturn(1L); // set messageChannelPartitionHandler.setMessagingOperations(operations);