Skip to content

Commit 7f24798

Browse files
roaks3pengq-google
authored andcommitted
Add check for version guards in erb templates (GoogleCloudPlatform#10297)
1 parent 37012e8 commit 7f24798

File tree

4 files changed

+194
-0
lines changed

4 files changed

+194
-0
lines changed

tools/template-check/go.mod

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module github.com/GoogleCloudPlatform/magic-modules/tools/template-check
2+
3+
go 1.21

tools/template-check/main.go

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package main
2+
3+
import (
4+
"bufio"
5+
"flag"
6+
"fmt"
7+
"os"
8+
9+
"github.com/GoogleCloudPlatform/magic-modules/tools/template-check/ruby"
10+
)
11+
12+
func isValidTemplate(filename string) (bool, error) {
13+
results, err := ruby.CheckVersionGuardsForFile(filename)
14+
if err != nil {
15+
return false, err
16+
}
17+
18+
if len(results) > 0 {
19+
fmt.Fprintf(os.Stderr, "error: invalid version checks found in %s:\n", filename)
20+
for _, result := range results {
21+
fmt.Fprintf(os.Stderr, " %s\n", result)
22+
}
23+
return false, nil
24+
}
25+
26+
return true, nil
27+
}
28+
29+
func checkTemplate(filename string) bool {
30+
valid, err := isValidTemplate(filename)
31+
if err != nil {
32+
fmt.Fprintln(os.Stderr, err.Error())
33+
return false
34+
}
35+
return valid
36+
}
37+
38+
func main() {
39+
flag.Usage = func() {
40+
fmt.Fprintf(flag.CommandLine.Output(), "template-check - check that a template file is valid\n template-check [file]\n")
41+
}
42+
43+
flag.Parse()
44+
45+
// Handle file as a positional argument
46+
if flag.Arg(0) != "" {
47+
if !checkTemplate(flag.Arg(0)) {
48+
os.Exit(1)
49+
}
50+
os.Exit(0)
51+
}
52+
53+
// Handle files coming from a linux pipe
54+
fileInfo, _ := os.Stdin.Stat()
55+
if fileInfo.Mode()&os.ModeCharDevice == 0 {
56+
exitStatus := 0
57+
scanner := bufio.NewScanner(bufio.NewReader(os.Stdin))
58+
for scanner.Scan() {
59+
if !checkTemplate(scanner.Text()) {
60+
exitStatus = 1
61+
}
62+
}
63+
64+
os.Exit(exitStatus)
65+
}
66+
}

tools/template-check/ruby/ruby.go

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package ruby
2+
3+
import (
4+
"bufio"
5+
"fmt"
6+
"io"
7+
"os"
8+
"regexp"
9+
"strings"
10+
)
11+
12+
// Note: this is allowlisted to combat other issues like `=` instead of `==`, but it is possible we
13+
// need to add more options to this list in the future, like `private`. The main thing we want to
14+
// prevent is targeting `beta` in version guards, because it mishandles either `ga` or `private`.
15+
var allowedGuards = []string{
16+
"<% unless version == 'ga' -%>",
17+
"<% if version == 'ga' -%>",
18+
"<% unless version == \"ga\" -%>",
19+
"<% if version == \"ga\" -%>",
20+
}
21+
22+
// Note: this does not account for _every_ possible use of a version guard (for example, those
23+
// starting with `version.nil?`), because the logic would start to get more complicated. Instead,
24+
// the goal is to capture (and validate) all "standard" version guards that would be added for new
25+
// resources/fields.
26+
func isVersionGuard(line string) bool {
27+
re := regexp.MustCompile("<% [a-z]+ version ")
28+
return re.MatchString(line)
29+
}
30+
31+
func isValidVersionGuard(line string) bool {
32+
for _, g := range allowedGuards {
33+
if strings.Contains(line, g) {
34+
return true
35+
}
36+
}
37+
return false
38+
}
39+
40+
// CheckVersionGuards scans the input for version guards, and checks that those version guards are
41+
// valid. Invalid version guards are returned along with the line number in which they occurred.
42+
func CheckVersionGuards(r io.Reader) []string {
43+
scanner := bufio.NewScanner(r)
44+
lineNum := 1
45+
var invalidGuards []string
46+
for scanner.Scan() {
47+
if isVersionGuard(scanner.Text()) && !isValidVersionGuard(scanner.Text()) {
48+
invalidGuards = append(invalidGuards, fmt.Sprintf("%d: %s", lineNum, scanner.Text()))
49+
}
50+
lineNum++
51+
}
52+
return invalidGuards
53+
}
54+
55+
// CheckVersionGuardsForFile scans the file for version guards, and checks that those version
56+
// guards are valid. Invalid version guards are returned along with the line number in which they
57+
// occurred.
58+
func CheckVersionGuardsForFile(filename string) ([]string, error) {
59+
file, err := os.Open(filename)
60+
if err != nil {
61+
return nil, err
62+
}
63+
defer file.Close()
64+
65+
return CheckVersionGuards(file), nil
66+
}
+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package ruby
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestCheckVersionGuards(t *testing.T) {
9+
cases := map[string]struct {
10+
fileText string
11+
expectedResults []string
12+
}{
13+
"allow standard format targeting ga": {
14+
fileText: "some text\n<% unless version == 'ga' -%>\nmore text",
15+
expectedResults: nil,
16+
},
17+
"disallow targeting beta": {
18+
fileText: "some text\n<% unless version == 'beta' -%>\nmore text",
19+
expectedResults: []string{"2: <% unless version == 'beta' -%>"},
20+
},
21+
"disallow 'if not'": {
22+
fileText: "some text\n<% if version != 'ga' -%>\nmore text",
23+
expectedResults: []string{"2: <% if version != 'ga' -%>"},
24+
},
25+
"disallow single '='": {
26+
fileText: "some text\n<% unless version = 'ga' -%>\nmore text",
27+
expectedResults: []string{"2: <% unless version = 'ga' -%>"},
28+
},
29+
"disallow leaving trailing line break": {
30+
fileText: "some text\n<% unless version == 'ga' %>\nmore text",
31+
expectedResults: []string{"2: <% unless version == 'ga' %>"},
32+
},
33+
"one valid, one invalid": {
34+
fileText: "some text\n<% unless version == 'beta' -%>\nmore text\n<% unless version == 'ga' -%>",
35+
expectedResults: []string{"2: <% unless version == 'beta' -%>"},
36+
},
37+
"multiple invalid": {
38+
fileText: "some text\n<% unless version == 'beta' -%>\nmore text\n\n\n<% if version == \"beta\" -%>",
39+
expectedResults: []string{"2: <% unless version == 'beta' -%>", "6: <% if version == \"beta\" -%>"},
40+
},
41+
}
42+
43+
for tn, tc := range cases {
44+
tc := tc
45+
t.Run(tn, func(t *testing.T) {
46+
t.Parallel()
47+
results := CheckVersionGuards(strings.NewReader(tc.fileText))
48+
if len(results) != len(tc.expectedResults) {
49+
t.Errorf("Expected length %d, got %d", len(tc.expectedResults), len(results))
50+
return
51+
}
52+
for i, result := range results {
53+
if result != tc.expectedResults[i] {
54+
t.Errorf("Expected %v, got %v", tc.expectedResults[i], result)
55+
}
56+
}
57+
})
58+
}
59+
}

0 commit comments

Comments
 (0)