Skip to content

Commit

Permalink
Generate reflection hints for main methods
Browse files Browse the repository at this point in the history
This commit makes sure to register the necessary hints to invoke the
main method of any bean available in the context. This is necessary
for tests that use the UseMainMethod feature.

This generates more hints than strictly necessary as there isn't a
way to contribute hints based on a ContextLoader, see
spring-projects/spring-framework#34513 for
more details.

Closes gh-44461
  • Loading branch information
snicoll committed Feb 28, 2025
1 parent c91c8e2 commit e1f45c5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2024 the original author or authors.
* Copyright 2012-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,11 +19,19 @@
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;

import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.aot.hint.ReflectionHints;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.boot.ApplicationContextFactory;
import org.springframework.boot.Banner;
import org.springframework.boot.ConfigurableBootstrapContext;
Expand Down Expand Up @@ -158,20 +166,23 @@ private Method getMainMethod(MergedContextConfiguration mergedConfig, UseMainMet
.orElse(null);
Assert.state(springBootConfiguration != null || useMainMethod == UseMainMethod.WHEN_AVAILABLE,
"Cannot use main method as no @SpringBootConfiguration-annotated class is available");
Method mainMethod = (springBootConfiguration != null)
? ReflectionUtils.findMethod(springBootConfiguration, "main", String[].class) : null;
Method mainMethod = findMainMethod(springBootConfiguration);
Assert.state(mainMethod != null || useMainMethod == UseMainMethod.WHEN_AVAILABLE,
() -> "Main method not found on '%s'".formatted(springBootConfiguration.getName()));
return mainMethod;
}

private static Method findMainMethod(Class<?> type) {
Method mainMethod = (type != null) ? ReflectionUtils.findMethod(type, "main", String[].class) : null;
if (mainMethod == null && KotlinDetector.isKotlinPresent()) {
try {
Class<?> kotlinClass = ClassUtils.forName(springBootConfiguration.getName() + "Kt",
springBootConfiguration.getClassLoader());
Class<?> kotlinClass = ClassUtils.forName(type.getName() + "Kt", type.getClassLoader());
mainMethod = ReflectionUtils.findMethod(kotlinClass, "main", String[].class);
}
catch (ClassNotFoundException ex) {
// Ignore
}
}
Assert.state(mainMethod != null || useMainMethod == UseMainMethod.WHEN_AVAILABLE,
() -> "Main method not found on '%s'".formatted(springBootConfiguration.getName()));
return mainMethod;
}

Expand Down Expand Up @@ -574,4 +585,39 @@ private ApplicationContext run(ThrowingSupplier<ConfigurableApplicationContext>

}

static class MainMethodBeanFactoryInitializationAotProcessor implements BeanFactoryInitializationAotProcessor {

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) {
List<Method> mainMethods = new ArrayList<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
Class<?> beanType = beanFactory.getType(beanName);
Method mainMethod = findMainMethod(beanType);
if (mainMethod != null) {
mainMethods.add(mainMethod);
}
}
return !mainMethods.isEmpty() ? new AotContribution(mainMethods) : null;
}

static class AotContribution implements BeanFactoryInitializationAotContribution {

private final Collection<Method> mainMethods;

AotContribution(Collection<Method> mainMethods) {
this.mainMethods = mainMethods;
}

@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {
ReflectionHints reflectionHints = generationContext.getRuntimeHints().reflection();
this.mainMethods.forEach((method) -> reflectionHints.registerMethod(method, ExecutableMode.INVOKE));
}

}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor=\
org.springframework.boot.test.context.SpringBootContextLoader.MainMethodBeanFactoryInitializationAotProcessor
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2024 the original author or authors.
* Copyright 2012-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,10 +25,16 @@
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.boot.ApplicationContextFactory;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootContextLoader.MainMethodBeanFactoryInitializationAotProcessor;
import org.springframework.boot.test.context.SpringBootTest.UseMainMethod;
import org.springframework.boot.test.util.TestPropertyValues;
import org.springframework.boot.web.reactive.context.GenericReactiveWebApplicationContext;
Expand Down Expand Up @@ -248,6 +254,35 @@ void whenUseMainMethodWithContextHierarchyThrowsException() {
.withMessage("UseMainMethod.ALWAYS cannot be used with @ContextHierarchy tests");
}

@Test
void whenMainMethodPresentRegisterReflectionHints() {
TestContext testContext = new ExposedTestContextManager(UseMainMethodWhenAvailableAndNoMainMethod.class)
.getExposedTestContext();
ConfigurableListableBeanFactory beanFactory = (ConfigurableListableBeanFactory) testContext
.getApplicationContext()
.getAutowireCapableBeanFactory();
BeanFactoryInitializationAotContribution aotContribution = new MainMethodBeanFactoryInitializationAotProcessor()
.processAheadOfTime(beanFactory);
assertThat(aotContribution).isNull();
}

@Test
void whenMainMethodNotAvailableReturnsNoAotContribution() {
TestContext testContext = new ExposedTestContextManager(UseMainMethodWhenAvailableAndMainMethod.class)
.getExposedTestContext();
ConfigurableListableBeanFactory beanFactory = (ConfigurableListableBeanFactory) testContext
.getApplicationContext()
.getAutowireCapableBeanFactory();
BeanFactoryInitializationAotContribution aotContribution = new MainMethodBeanFactoryInitializationAotProcessor()
.processAheadOfTime(beanFactory);
assertThat(aotContribution).isNotNull();
TestGenerationContext generationContext = new TestGenerationContext();
aotContribution.applyTo(generationContext, null);
RuntimeHints runtimeHints = generationContext.getRuntimeHints();
assertThat(RuntimeHintsPredicates.reflection().onMethod(ConfigWithMain.class, "main").invoke())
.accepts(runtimeHints);
}

@Test
void whenSubclassProvidesCustomApplicationContextFactory() {
TestContext testContext = new ExposedTestContextManager(CustomApplicationContextTest.class)
Expand Down

0 comments on commit e1f45c5

Please sign in to comment.