Skip to content

Commit 17f9973

Browse files
committed
Allow setting wrapper targets based on annotations
This allows for registering that, e.g. ```java @RequiresNetwork class SomeTest {} ``` should be generated as: ```starlark requires_network( java_test, ... ) ``` instead of: ```starlark java_test( ... ) ``` In suite mode, separate targets will be generated for these special targets.
1 parent 8b833d8 commit 17f9973

File tree

15 files changed

+581
-19
lines changed

15 files changed

+581
-19
lines changed

java/gazelle/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ go_library(
2020
"//java/gazelle/private/javaparser",
2121
"//java/gazelle/private/logconfig",
2222
"//java/gazelle/private/maven",
23+
"//java/gazelle/private/sorted_multiset",
2324
"//java/gazelle/private/sorted_set",
2425
"//java/gazelle/private/types",
2526
"@bazel_gazelle//config:go_default_library",

java/gazelle/configure.go

+50
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,21 @@ import (
2020
type Configurer struct {
2121
lang *javaLang
2222
annotationToAttribute annotationToAttribute
23+
annotationToWrapper annotationToWrapper
2324
mavenInstallFile string
2425
}
2526

2627
func NewConfigurer(lang *javaLang) *Configurer {
2728
return &Configurer{
2829
lang: lang,
2930
annotationToAttribute: make(annotationToAttribute),
31+
annotationToWrapper: make(annotationToWrapper),
3032
}
3133
}
3234

3335
func (jc *Configurer) RegisterFlags(fs *flag.FlagSet, cmd string, c *config.Config) {
3436
fs.Var(&jc.annotationToAttribute, "java-annotation-to-attribute", "Mapping of annotations (on test classes) to attributes which should be set for that test rule. Examples: com.example.annotations.FlakyTest=flaky=True com.example.annotations.SlowTest=timeout=\"long\"")
37+
fs.Var(&jc.annotationToWrapper, "java-annotation-to-wrapper", "Mapping of annotations (on test classes) to wrapper rules which should be used around the test rule. Example: com.example.annotations.RequiresNetwork=@some//wrapper:file.bzl=requires_network")
3538
fs.StringVar(&jc.mavenInstallFile, "java-maven-install-file", "", "Path of the maven_install.json file. Defaults to \"maven_install.json\".")
3639
}
3740

@@ -42,6 +45,9 @@ func (jc *Configurer) CheckFlags(fs *flag.FlagSet, c *config.Config) error {
4245
cfgs[""].MapAnnotationToAttribute(annotation, k, v)
4346
}
4447
}
48+
for annotation, wrapper := range jc.annotationToWrapper {
49+
cfgs[""].MapAnnotationToWrapper(annotation, wrapper.symbol)
50+
}
4551
if jc.mavenInstallFile != "" {
4652
cfgs[""].SetMavenInstallFile(jc.mavenInstallFile)
4753
}
@@ -192,3 +198,47 @@ func (f *annotationToAttribute) Set(value string) error {
192198
(*f)[annotationClassName][key] = parsedValue
193199
return nil
194200
}
201+
202+
type loadInfo struct {
203+
from string
204+
symbol string
205+
}
206+
207+
type annotationToWrapper map[string]loadInfo
208+
209+
func (f *annotationToWrapper) String() string {
210+
s := "annotationToWrapper{"
211+
for a, li := range *f {
212+
s += a + ": "
213+
s += fmt.Sprintf(`load("%s", "%s")`, li.from, li.symbol)
214+
}
215+
s += "}"
216+
return s
217+
}
218+
219+
func (f *annotationToWrapper) Set(value string) error {
220+
parts := strings.Split(value, "=")
221+
if len(parts) != 2 {
222+
return fmt.Errorf("want --java-annotation-to-wrapper to have format com.example.RequiresNetwork=@some_repo//has:wrapper.bzl,wrapper_rule but didn't see exactly one equals sign")
223+
}
224+
annotation := parts[0]
225+
226+
if _, ok := (*f)[annotation]; ok {
227+
return fmt.Errorf("saw conflicting values for --java-annotation-to-wrapper flag for annotation %v", annotation)
228+
}
229+
230+
vParts := strings.Split(parts[1], ",")
231+
if len(vParts) != 2 {
232+
return fmt.Errorf("want --java-annotation-to-wrapper to have format com.example.RequiresNetwork=@some_repo//has:wrapper.bzl,wrapper_rule but didn't see exactly one comma after equals sign")
233+
}
234+
235+
from := vParts[0]
236+
symbol := vParts[1]
237+
238+
(*f)[annotation] = loadInfo{
239+
from: from,
240+
symbol: symbol,
241+
}
242+
243+
return nil
244+
}

java/gazelle/generate.go

+37-11
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ func javaFileLess(l, r javaFile) bool {
3636
return l.pathRelativeToBazelWorkspaceRoot < r.pathRelativeToBazelWorkspaceRoot
3737
}
3838

39+
type separateJavaTestReasons struct {
40+
attributes map[string]bzl.Expr
41+
wrapper string
42+
}
43+
3944
// GenerateRules extracts build metadata from source files in a directory.
4045
//
4146
// See language.GenerateRules for more information.
@@ -114,7 +119,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
114119
testJavaImports := sorted_set.NewSortedSetFn([]types.PackageName{}, types.PackageNameLess)
115120

116121
// Java Test files which need to be generated separately from any others because they have explicit attribute overrides.
117-
separateTestJavaFiles := make(map[javaFile]map[string]bzl.Expr)
122+
separateTestJavaFiles := make(map[javaFile]separateJavaTestReasons)
118123

119124
// Files which are used by non-test classes in test java packages.
120125
testHelperJavaFiles := sorted_set.NewSortedSetFn([]javaFile{}, javaFileLess)
@@ -234,8 +239,8 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
234239
switch cfg.TestMode() {
235240
case "file":
236241
for _, tf := range testJavaFiles.SortedSlice() {
237-
extraAttributes := separateTestJavaFiles[tf]
238-
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, extraAttributes, &res)
242+
separateJavaTestReasons := separateTestJavaFiles[tf]
243+
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
239244
}
240245

241246
case "suite":
@@ -278,7 +283,8 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
278283
if testHelperJavaFiles.Len() > 0 {
279284
testHelperDep = ptr(testHelperLibname(suiteName))
280285
}
281-
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateTestJavaFiles[src], &res)
286+
separateJavaTestReasons := separateTestJavaFiles[src]
287+
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
282288
}
283289
}
284290
}
@@ -407,10 +413,11 @@ func addFilteringOutOwnPackage(to *sorted_set.SortedSet[types.PackageName], from
407413
}
408414
}
409415

410-
func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFiles *sorted_set.SortedSet[javaFile], separateTestJavaFiles map[javaFile]map[string]bzl.Expr, file javaFile, perClassMetadata map[string]java.PerClassMetadata, log zerolog.Logger) {
416+
func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFiles *sorted_set.SortedSet[javaFile], separateTestJavaFiles map[javaFile]separateJavaTestReasons, file javaFile, perClassMetadata map[string]java.PerClassMetadata, log zerolog.Logger) {
411417
if cfg.IsJavaTestFile(filepath.Base(file.pathRelativeToBazelWorkspaceRoot)) {
412418
annotationClassNames := perClassMetadata[file.ClassName().FullyQualifiedClassName()].AnnotationClassNames
413419
perFileAttrs := make(map[string]bzl.Expr)
420+
wrapper := ""
414421
for _, annotationClassName := range annotationClassNames.SortedSlice() {
415422
if attrs, ok := cfg.AttributesForAnnotation(annotationClassName); ok {
416423
for k, v := range attrs {
@@ -420,10 +427,20 @@ func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFil
420427
perFileAttrs[k] = v
421428
}
422429
}
430+
newWrapper, ok := cfg.WrapperForAnnotation(annotationClassName)
431+
if ok {
432+
if wrapper != "" {
433+
log.Error().Str("file", file.pathRelativeToBazelWorkspaceRoot).Msgf("Saw conflicting wrappers from annotations: %v and %v. Picking one at random.", wrapper, newWrapper)
434+
}
435+
wrapper = newWrapper
436+
}
423437
}
424438
testJavaFiles.Add(file)
425-
if len(perFileAttrs) > 0 {
426-
separateTestJavaFiles[file] = perFileAttrs
439+
if len(perFileAttrs) > 0 || wrapper != "" {
440+
separateTestJavaFiles[file] = separateJavaTestReasons{
441+
attributes: perFileAttrs,
442+
wrapper: wrapper,
443+
}
427444
}
428445
} else {
429446
testHelperJavaFiles.Add(file)
@@ -488,7 +505,7 @@ func (l javaLang) generateJavaBinary(file *rule.File, m types.ClassName, libName
488505
})
489506
}
490507

491-
func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
508+
func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, wrapper string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
492509
className := f.ClassName()
493510
fullyQualifiedTestClass := className.FullyQualifiedClassName()
494511
var testName string
@@ -498,12 +515,12 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
498515
testName = className.BareOuterClassName()
499516
}
500517

501-
ruleKind := "java_test"
518+
javaRuleKind := "java_test"
502519
if importsJunit5(imports) {
503-
ruleKind = "java_junit5_test"
520+
javaRuleKind = "java_junit5_test"
504521
}
505522

506-
runtimeDeps := l.collectRuntimeDeps(ruleKind, testName, file)
523+
runtimeDeps := l.collectRuntimeDeps(javaRuleKind, testName, file)
507524
if importsJunit5(imports) {
508525
// This should probably register imports here, and then allow the
509526
// resolver to resolve this to an artifact, but we don't currently wire
@@ -514,7 +531,16 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
514531
}
515532
}
516533

534+
ruleKind := javaRuleKind
535+
if wrapper != "" {
536+
ruleKind = wrapper
537+
}
538+
517539
r := rule.NewRule(ruleKind, testName)
540+
if wrapper != "" {
541+
r.AddArg(&bzl.Ident{Name: javaRuleKind})
542+
}
543+
518544
path := strings.TrimPrefix(f.pathRelativeToBazelWorkspaceRoot, pathToPackageRelativeToBazelWorkspace+"/")
519545
r.SetAttr("srcs", []string{path})
520546
r.SetAttr("test_class", fullyQualifiedTestClass)

java/gazelle/generate_test.go

+27-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/sorted_set"
77
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/types"
88
"github.com/bazelbuild/bazel-gazelle/language"
9+
bzl "github.com/bazelbuild/buildtools/build"
910
"github.com/google/go-cmp/cmp"
1011
"github.com/rs/zerolog"
1112
"github.com/stretchr/testify/require"
@@ -19,10 +20,12 @@ func TestSingleJavaTestFile(t *testing.T) {
1920
type testCase struct {
2021
includePackageInName bool
2122
importedPackages []string
23+
wrapper string
2224
wantRuleKind string
2325
wantImports []string
2426
wantDeps []string
2527
wantRuntimeDeps []string
28+
wantArgs []bzl.Expr
2629
}
2730

2831
for name, tc := range map[string]testCase{
@@ -86,6 +89,14 @@ func TestSingleJavaTestFile(t *testing.T) {
8689
wantRuleKind: "java_test",
8790
wantImports: []string{"com.example", "org.junit"},
8891
},
92+
"wrapper junit4": {
93+
includePackageInName: false,
94+
importedPackages: []string{"org.junit"},
95+
wrapper: "some_wrapper",
96+
wantRuleKind: "some_wrapper",
97+
wantImports: []string{"com.example", "org.junit"},
98+
wantArgs: []bzl.Expr{&bzl.Ident{Name: "java_test"}},
99+
},
89100
"explicit junit5": {
90101
includePackageInName: false,
91102
importedPackages: []string{"org.junit.jupiter.api"},
@@ -119,6 +130,19 @@ func TestSingleJavaTestFile(t *testing.T) {
119130
"@maven//:org_junit_platform_junit_platform_reporting",
120131
},
121132
},
133+
"wrapper junit5": {
134+
includePackageInName: false,
135+
importedPackages: []string{"org.junit.jupiter.api"},
136+
wrapper: "some_wrapper",
137+
wantRuleKind: "some_wrapper",
138+
wantImports: []string{"com.example", "org.junit.jupiter.api"},
139+
wantRuntimeDeps: []string{
140+
"@maven//:org_junit_jupiter_junit_jupiter_engine",
141+
"@maven//:org_junit_platform_junit_platform_launcher",
142+
"@maven//:org_junit_platform_junit_platform_reporting",
143+
},
144+
wantArgs: []bzl.Expr{&bzl.Ident{Name: "java_junit5_test"}},
145+
},
122146
"explicit both junit4 and junit5": {
123147
includePackageInName: false,
124148
importedPackages: []string{"org.junit", "org.junit.jupiter.api"},
@@ -135,7 +159,7 @@ func TestSingleJavaTestFile(t *testing.T) {
135159
var res language.GenerateResult
136160

137161
l := newTestJavaLang(t)
138-
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, nil, &res)
162+
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, tc.wrapper, nil, &res)
139163

140164
require.Len(t, res.Gen, 1, "want 1 generated rule")
141165

@@ -154,6 +178,7 @@ func TestSingleJavaTestFile(t *testing.T) {
154178
wantAttrs = append(wantAttrs, "runtime_deps")
155179
}
156180
require.ElementsMatch(t, wantAttrs, rule.AttrKeys())
181+
require.ElementsMatch(t, tc.wantArgs, rule.Args())
157182

158183
require.Len(t, res.Imports, 1, "want 1 generated importedPackages")
159184
wantImports := sorted_set.NewSortedSetFn([]types.PackageName{}, types.PackageNameLess)
@@ -165,6 +190,7 @@ func TestSingleJavaTestFile(t *testing.T) {
165190
if len(tc.wantRuntimeDeps) > 0 {
166191
require.ElementsMatch(t, tc.wantRuntimeDeps, rule.AttrStrings("runtime_deps"))
167192
}
193+
168194
})
169195
}
170196
}

java/gazelle/javaconfig/config.go

+24
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ func (c *Config) NewChild() *Config {
7070
testMode: c.testMode,
7171
customTestFileSuffixes: c.customTestFileSuffixes,
7272
annotationToAttribute: c.annotationToAttribute,
73+
annotationToWrapper: c.annotationToWrapper,
7374
excludedArtifacts: clonedExcludedArtifacts,
7475
mavenRepositoryName: c.mavenRepositoryName,
7576
}
@@ -99,6 +100,7 @@ type Config struct {
99100
customTestFileSuffixes *[]string
100101
excludedArtifacts map[string]struct{}
101102
annotationToAttribute map[string]map[string]bzl.Expr
103+
annotationToWrapper map[string]string
102104
mavenRepositoryName string
103105
}
104106

@@ -120,6 +122,7 @@ func New(repoRoot string) *Config {
120122
customTestFileSuffixes: nil,
121123
excludedArtifacts: make(map[string]struct{}),
122124
annotationToAttribute: make(map[string]map[string]bzl.Expr),
125+
annotationToWrapper: make(map[string]string),
123126
mavenRepositoryName: "maven",
124127
}
125128
}
@@ -244,6 +247,27 @@ func (c *Config) AttributesForAnnotation(annotation string) (map[string]bzl.Expr
244247
return m, ok
245248
}
246249

250+
func (c *Config) MapAnnotationToWrapper(annotation string, wrapper string) {
251+
c.annotationToWrapper[annotation] = wrapper
252+
}
253+
254+
func (c *Config) WrapperForAnnotation(annotation string) (string, bool) {
255+
s, ok := c.annotationToWrapper[annotation]
256+
return s, ok
257+
}
258+
259+
func (c *Config) IsTestRule(ruleKind string) bool {
260+
if ruleKind == "java_junit5_test" || ruleKind == "java_test" || ruleKind == "java_test_suite" {
261+
return true
262+
}
263+
for _, wrapper := range c.annotationToWrapper {
264+
if ruleKind == wrapper {
265+
return true
266+
}
267+
}
268+
return false
269+
}
270+
247271
func equalStringSlices(l, r []string) bool {
248272
if len(l) != len(r) {
249273
return false

0 commit comments

Comments
 (0)