Skip to content

Commit c2fdf83

Browse files
committed
refactor template parsing logic
Move the logic for selecting a license template based on user input into a standalone func (fetchTemplate), and add test cases for all code paths. Delay parsing predefined license templates. This allows the new fetchTemplate method to modify these templates before returning in the future (to add SPDX license information). Add tests to ensure that these templates must always parse properly. Rename copyrightData type to licenseData, since we will soon begin to add more than just copyright data here (SPDX ID). Rename prefix func to executeTemplate, since this better describes what the function is doing. These are all refactoring and cleanup changes; no behavioral changes.
1 parent 6d92264 commit c2fdf83

File tree

4 files changed

+167
-47
lines changed

4 files changed

+167
-47
lines changed

main.go

+28-36
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ var (
5858
checkonly = flag.Bool("check", false, "check only mode: verify presence of license headers and exit with non-zero code if missing")
5959
)
6060

61+
func init() {
62+
flag.Usage = func() {
63+
fmt.Fprintln(os.Stderr, helpText)
64+
flag.PrintDefaults()
65+
}
66+
flag.Var(&skipExtensionFlags, "skip", "To skip files to check/add the header file, for example: -skip rb -skip go")
67+
}
68+
6169
type skipExtensionFlag []string
6270

6371
func (i *skipExtensionFlag) String() string {
@@ -70,40 +78,24 @@ func (i *skipExtensionFlag) Set(value string) error {
7078
}
7179

7280
func main() {
73-
flag.Usage = func() {
74-
fmt.Fprintln(os.Stderr, helpText)
75-
flag.PrintDefaults()
76-
}
77-
flag.Var(&skipExtensionFlags, "skip", "To skip files to check/add the header file, for example: -skip rb -skip go")
7881
flag.Parse()
7982
if flag.NArg() == 0 {
8083
flag.Usage()
8184
os.Exit(1)
8285
}
8386

84-
data := &copyrightData{
87+
data := licenseData{
8588
Year: *year,
8689
Holder: *holder,
8790
}
8891

89-
var t *template.Template
90-
if *licensef != "" {
91-
d, err := ioutil.ReadFile(*licensef)
92-
if err != nil {
93-
log.Printf("license file: %v", err)
94-
os.Exit(1)
95-
}
96-
t, err = template.New("").Parse(string(d))
97-
if err != nil {
98-
log.Printf("license file: %v", err)
99-
os.Exit(1)
100-
}
101-
} else {
102-
t = licenseTemplate[*license]
103-
if t == nil {
104-
log.Printf("unknown license: %s", *license)
105-
os.Exit(1)
106-
}
92+
tpl, err := fetchTemplate(*license, *licensef)
93+
if err != nil {
94+
log.Fatal(err)
95+
}
96+
t, err := template.New("").Parse(tpl)
97+
if err != nil {
98+
log.Fatal(err)
10799
}
108100

109101
// process at most 1000 files in parallel
@@ -189,7 +181,7 @@ func walk(ch chan<- *file, start string) {
189181
// addLicense add a license to the file if missing.
190182
//
191183
// It returns true if the file was updated.
192-
func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) (bool, error) {
184+
func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data licenseData) (bool, error) {
193185
var lic []byte
194186
var err error
195187
lic, err = licenseHeader(path, tmpl, data)
@@ -227,32 +219,32 @@ func fileHasLicense(path string) (bool, error) {
227219
return hasLicense(b) || isGenerated(b), nil
228220
}
229221

230-
func licenseHeader(path string, tmpl *template.Template, data *copyrightData) ([]byte, error) {
222+
func licenseHeader(path string, tmpl *template.Template, data licenseData) ([]byte, error) {
231223
var lic []byte
232224
var err error
233225
switch fileExtension(path) {
234226
default:
235227
return nil, nil
236228
case ".c", ".h", ".gv":
237-
lic, err = prefix(tmpl, data, "/*", " * ", " */")
229+
lic, err = executeTemplate(tmpl, data, "/*", " * ", " */")
238230
case ".js", ".mjs", ".cjs", ".jsx", ".tsx", ".css", ".scss", ".sass", ".tf", ".ts":
239-
lic, err = prefix(tmpl, data, "/**", " * ", " */")
231+
lic, err = executeTemplate(tmpl, data, "/**", " * ", " */")
240232
case ".cc", ".cpp", ".cs", ".go", ".hh", ".hpp", ".java", ".m", ".mm", ".proto", ".rs", ".scala", ".swift", ".dart", ".groovy", ".kt", ".kts", ".v", ".sv":
241-
lic, err = prefix(tmpl, data, "", "// ", "")
233+
lic, err = executeTemplate(tmpl, data, "", "// ", "")
242234
case ".py", ".sh", ".yaml", ".yml", ".dockerfile", "dockerfile", ".rb", "gemfile", ".tcl", ".bzl":
243-
lic, err = prefix(tmpl, data, "", "# ", "")
235+
lic, err = executeTemplate(tmpl, data, "", "# ", "")
244236
case ".el", ".lisp":
245-
lic, err = prefix(tmpl, data, "", ";; ", "")
237+
lic, err = executeTemplate(tmpl, data, "", ";; ", "")
246238
case ".erl":
247-
lic, err = prefix(tmpl, data, "", "% ", "")
239+
lic, err = executeTemplate(tmpl, data, "", "% ", "")
248240
case ".hs", ".sql", ".sdl":
249-
lic, err = prefix(tmpl, data, "", "-- ", "")
241+
lic, err = executeTemplate(tmpl, data, "", "-- ", "")
250242
case ".html", ".xml", ".vue":
251-
lic, err = prefix(tmpl, data, "<!--", " ", "-->")
243+
lic, err = executeTemplate(tmpl, data, "<!--", " ", "-->")
252244
case ".php":
253-
lic, err = prefix(tmpl, data, "", "// ", "")
245+
lic, err = executeTemplate(tmpl, data, "", "// ", "")
254246
case ".ml", ".mli", ".mll", ".mly":
255-
lic, err = prefix(tmpl, data, "(**", " ", "*)")
247+
lic, err = executeTemplate(tmpl, data, "(**", " ", "*)")
256248
}
257249
return lic, err
258250
}

testdata/custom.tpl

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Copyright {{.Year}} {{.Holder}}
2+
3+
Custom License Template

tmpl.go

+34-11
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,50 @@ import (
1919
"bytes"
2020
"fmt"
2121
"html/template"
22+
"io/ioutil"
2223
"strings"
2324
"unicode"
2425
)
2526

26-
var licenseTemplate = make(map[string]*template.Template)
27+
var licenseTemplate = map[string]string{
28+
"apache": tmplApache,
29+
"mit": tmplMIT,
30+
"bsd": tmplBSD,
31+
"mpl": tmplMPL,
32+
}
2733

28-
func init() {
29-
licenseTemplate["apache"] = template.Must(template.New("").Parse(tmplApache))
30-
licenseTemplate["mit"] = template.Must(template.New("").Parse(tmplMIT))
31-
licenseTemplate["bsd"] = template.Must(template.New("").Parse(tmplBSD))
32-
licenseTemplate["mpl"] = template.Must(template.New("").Parse(tmplMPL))
34+
// licenseData specifies the data used to fill out a license template.
35+
type licenseData struct {
36+
Year string // Copyright year(s).
37+
Holder string // Name of the copyright holder.
3338
}
3439

35-
type copyrightData struct {
36-
Year string
37-
Holder string
40+
// fetchTemplate returns the license template for the specified license and
41+
// optional templateFile. If templateFile is provided, the license is read
42+
// from the specified file. Otherwise, a template is loaded for the specified
43+
// license, if recognized.
44+
func fetchTemplate(license string, templateFile string) (string, error) {
45+
var t string
46+
if templateFile != "" {
47+
d, err := ioutil.ReadFile(templateFile)
48+
if err != nil {
49+
return "", fmt.Errorf("license file: %w", err)
50+
}
51+
52+
t = string(d)
53+
} else {
54+
t = licenseTemplate[license]
55+
if t == "" {
56+
return "", fmt.Errorf("unknown license: %q", license)
57+
}
58+
}
59+
60+
return t, nil
3861
}
3962

40-
// prefix will execute a license template t with data d
63+
// executeTemplate will execute a license template t with data d
4164
// and prefix the result with top, middle and bottom.
42-
func prefix(t *template.Template, d *copyrightData, top, mid, bot string) ([]byte, error) {
65+
func executeTemplate(t *template.Template, d licenseData, top, mid, bot string) ([]byte, error) {
4366
var buf bytes.Buffer
4467
if err := t.Execute(&buf, d); err != nil {
4568
return nil, err

tmpl_test.go

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2018 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"errors"
19+
"html/template"
20+
"os"
21+
"testing"
22+
)
23+
24+
func init() {
25+
// ensure that pre-defined templates must parse
26+
template.Must(template.New("").Parse(tmplApache))
27+
template.Must(template.New("").Parse(tmplMIT))
28+
template.Must(template.New("").Parse(tmplBSD))
29+
template.Must(template.New("").Parse(tmplMPL))
30+
}
31+
32+
func TestFetchTemplate(t *testing.T) {
33+
tests := []struct {
34+
description string // test case description
35+
license string // license passed to fetchTemplate
36+
templateFile string // templatefile passed to fetchTemplate
37+
wantTemplate string // expected returned template
38+
wantErr error // expected returned error
39+
}{
40+
{
41+
"non-existant template file",
42+
"",
43+
"/does/not/exist",
44+
"",
45+
os.ErrNotExist,
46+
},
47+
{
48+
"custom template file",
49+
"",
50+
"testdata/custom.tpl",
51+
"Copyright {{.Year}} {{.Holder}}\n\nCustom License Template\n",
52+
nil,
53+
},
54+
{
55+
"unknown license",
56+
"unknown",
57+
"",
58+
"",
59+
errors.New(`unknown license: "unknown"`),
60+
},
61+
{
62+
"apache license template",
63+
"apache",
64+
"",
65+
tmplApache,
66+
nil,
67+
},
68+
{
69+
"mit license template",
70+
"mit",
71+
"",
72+
tmplMIT,
73+
nil,
74+
},
75+
{
76+
"bsd license template",
77+
"bsd",
78+
"",
79+
tmplBSD,
80+
nil,
81+
},
82+
{
83+
"mpl license template",
84+
"mpl",
85+
"",
86+
tmplMPL,
87+
nil,
88+
},
89+
}
90+
91+
for _, tt := range tests {
92+
t.Run(tt.description, func(t *testing.T) {
93+
tpl, err := fetchTemplate(tt.license, tt.templateFile)
94+
if tt.wantErr != nil && (err == nil || (!errors.Is(err, tt.wantErr) && err.Error() != tt.wantErr.Error())) {
95+
t.Fatalf("fetchTemplate(%q, %q) returned error: %#v, want %#v", tt.license, tt.templateFile, err, tt.wantErr)
96+
}
97+
if tpl != tt.wantTemplate {
98+
t.Errorf("fetchTemplate(%q, %q) returned template: %q, want %q", tt.license, tt.templateFile, tpl, tt.wantTemplate)
99+
}
100+
})
101+
}
102+
}

0 commit comments

Comments
 (0)