Skip to content

Commit 27146d5

Browse files
committed
Implement -check flag
When this flag is used: * The program never modifies any files * If all files in the pattern contain a license, the program exits with a zero exit code * If at least one file in the pattern requires modification to include license text, the program prints such files to STDOUT and exits with a non-zero exit code
1 parent c464135 commit 27146d5

File tree

3 files changed

+118
-32
lines changed

3 files changed

+118
-32
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ to any file that already has one.
1818
-f custom license file (no default)
1919
-l license type: apache, bsd, mit (defaults to "apache")
2020
-y year (defaults to current year)
21+
-check check only mode: verify presence of license headers and exit with non-zero code if missing
2122

2223
The pattern argument can be provided multiple times, and may also refer
2324
to single files.

main.go

+72-32
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package main
1818

1919
import (
2020
"bytes"
21+
"errors"
2122
"flag"
2223
"fmt"
2324
"html/template"
@@ -46,11 +47,12 @@ Flags:
4647
`
4748

4849
var (
49-
holder = flag.String("c", "Google LLC", "copyright holder")
50-
license = flag.String("l", "apache", "license type: apache, bsd, mit")
51-
licensef = flag.String("f", "", "license file")
52-
year = flag.String("y", fmt.Sprint(time.Now().Year()), "copyright year(s)")
53-
verbose = flag.Bool("v", false, "verbose mode: print the name of the files that are modified")
50+
holder = flag.String("c", "Google LLC", "copyright holder")
51+
license = flag.String("l", "apache", "license type: apache, bsd, mit")
52+
licensef = flag.String("f", "", "license file")
53+
year = flag.String("y", fmt.Sprint(time.Now().Year()), "copyright year(s)")
54+
verbose = flag.Bool("v", false, "verbose mode: print the name of the files that are modified")
55+
checkonly = flag.Bool("check", false, "check only mode: verify presence of license headers and exit with non-zero code if missing")
5456
)
5557

5658
func main() {
@@ -97,13 +99,35 @@ func main() {
9799
for f := range ch {
98100
f := f // https://golang.org/doc/faq#closures_and_goroutines
99101
wg.Go(func() error {
100-
modified, err := addLicense(f.path, f.mode, t, data)
101-
if err != nil {
102-
log.Printf("%s: %v", f.path, err)
103-
return err
104-
}
105-
if *verbose && modified {
106-
log.Printf("%s modified", f.path)
102+
if *checkonly {
103+
// Check if file extension is known
104+
lic, err := licenseHeader(f.path, t, data)
105+
if err != nil {
106+
log.Printf("%s: %v", f.path, err)
107+
return err
108+
}
109+
if lic == nil { // Unknown fileExtension
110+
return nil
111+
}
112+
// Check if file has a license
113+
isMissingLicenseHeader, err := fileHasLicense(f.path)
114+
if err != nil {
115+
log.Printf("%s: %v", f.path, err)
116+
return err
117+
}
118+
if isMissingLicenseHeader {
119+
fmt.Printf("%s\n", f.path)
120+
return errors.New("missing license header")
121+
}
122+
} else {
123+
modified, err := addLicense(f.path, f.mode, t, data)
124+
if err != nil {
125+
log.Printf("%s: %v", f.path, err)
126+
return err
127+
}
128+
if *verbose && modified {
129+
log.Printf("%s modified", f.path)
130+
}
107131
}
108132
return nil
109133
})
@@ -142,11 +166,45 @@ func walk(ch chan<- *file, start string) {
142166
}
143167

144168
func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) (bool, error) {
169+
var lic []byte
170+
var err error
171+
lic, err = licenseHeader(path, tmpl, data)
172+
if err != nil || lic == nil {
173+
return false, err
174+
}
175+
176+
b, err := ioutil.ReadFile(path)
177+
if err != nil || hasLicense(b) {
178+
return false, err
179+
}
180+
181+
line := hashBang(b)
182+
if len(line) > 0 {
183+
b = b[len(line):]
184+
if line[len(line)-1] != '\n' {
185+
line = append(line, '\n')
186+
}
187+
lic = append(line, lic...)
188+
}
189+
b = append(lic, b...)
190+
return true, ioutil.WriteFile(path, b, fmode)
191+
}
192+
193+
// fileHasLicense reports whether the file at path contains a license header.
194+
func fileHasLicense(path string) (bool, error) {
195+
b, err := ioutil.ReadFile(path)
196+
if err != nil || hasLicense(b) {
197+
return false, err
198+
}
199+
return true, nil
200+
}
201+
202+
func licenseHeader(path string, tmpl *template.Template, data *copyrightData) ([]byte, error) {
145203
var lic []byte
146204
var err error
147205
switch fileExtension(path) {
148206
default:
149-
return false, nil
207+
return nil, nil
150208
case ".c", ".h":
151209
lic, err = prefix(tmpl, data, "/*", " * ", " */")
152210
case ".js", ".jsx", ".tsx", ".css", ".tf", ".ts":
@@ -168,25 +226,7 @@ func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *c
168226
case ".ml", ".mli", ".mll", ".mly":
169227
lic, err = prefix(tmpl, data, "(**", " ", "*)")
170228
}
171-
if err != nil || lic == nil {
172-
return false, err
173-
}
174-
175-
b, err := ioutil.ReadFile(path)
176-
if err != nil || hasLicense(b) {
177-
return false, err
178-
}
179-
180-
line := hashBang(b)
181-
if len(line) > 0 {
182-
b = b[len(line):]
183-
if line[len(line)-1] != '\n' {
184-
line = append(line, '\n')
185-
}
186-
lic = append(line, lic...)
187-
}
188-
b = append(lic, b...)
189-
return true, ioutil.WriteFile(path, b, fmode)
229+
return lic, err
190230
}
191231

192232
func fileExtension(name string) string {

main_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,48 @@ func TestReadErrors(t *testing.T) {
139139
}
140140
run(t, "chmod", "0644", samplefile)
141141
}
142+
143+
func TestCheckSuccess(t *testing.T) {
144+
if os.Getenv("RUNME") != "" {
145+
main()
146+
return
147+
}
148+
149+
tmp := tempDir(t)
150+
t.Logf("tmp dir: %s", tmp)
151+
samplefile := filepath.Join(tmp, "file.c")
152+
153+
run(t, "cp", "testdata/expected/file.c", samplefile)
154+
cmd := exec.Command(os.Args[0],
155+
"-test.run=TestCheckSuccess",
156+
"-l", "apache", "-c", "Google LLC", "-y", "2018",
157+
"-check", samplefile,
158+
)
159+
cmd.Env = []string{"RUNME=1"}
160+
if out, err := cmd.CombinedOutput(); err != nil {
161+
t.Fatalf("%v\n%s", err, out)
162+
}
163+
}
164+
165+
func TestCheckFail(t *testing.T) {
166+
if os.Getenv("RUNME") != "" {
167+
main()
168+
return
169+
}
170+
171+
tmp := tempDir(t)
172+
t.Logf("tmp dir: %s", tmp)
173+
samplefile := filepath.Join(tmp, "file.c")
174+
175+
run(t, "cp", "testdata/initial/file.c", samplefile)
176+
cmd := exec.Command(os.Args[0],
177+
"-test.run=TestCheckFail",
178+
"-l", "apache", "-c", "Google LLC", "-y", "2018",
179+
"-check", samplefile,
180+
)
181+
cmd.Env = []string{"RUNME=1"}
182+
out, err := cmd.CombinedOutput()
183+
if err == nil {
184+
t.Fatalf("TestCheckFail exited with a zero exit code.\n%s", out)
185+
}
186+
}

0 commit comments

Comments
 (0)